dreamlessx commited on
Commit
9475836
·
verified ·
1 Parent(s): a1b1648

Upload landmarkdiff/checkpoint_manager.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. landmarkdiff/checkpoint_manager.py +365 -0
landmarkdiff/checkpoint_manager.py ADDED
@@ -0,0 +1,365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Checkpoint management with metadata tracking, best-model selection, and pruning.
2
+
3
+ Provides a central manager for training checkpoints that:
4
+ - Tracks per-checkpoint metadata (step, metrics, timestamps)
5
+ - Maintains symlinks to best/latest checkpoints
6
+ - Prunes old checkpoints to save disk space
7
+ - Supports multiple ranking metrics (loss, FID, SSIM, etc.)
8
+
9
+ Usage:
10
+ manager = CheckpointManager(
11
+ output_dir="checkpoints/phaseA",
12
+ keep_best=3,
13
+ keep_latest=5,
14
+ metric="loss",
15
+ lower_is_better=True,
16
+ )
17
+
18
+ # During training loop:
19
+ manager.save(
20
+ step=1000,
21
+ controlnet=controlnet,
22
+ ema_controlnet=ema_controlnet,
23
+ optimizer=optimizer,
24
+ scheduler=scheduler,
25
+ metrics={"loss": 0.0123, "val_ssim": 0.87},
26
+ )
27
+ """
28
+
29
+ from __future__ import annotations
30
+
31
+ import json
32
+ import shutil
33
+ import time
34
+ from dataclasses import asdict, dataclass, field
35
+ from pathlib import Path
36
+ from typing import Any
37
+
38
+ import torch
39
+
40
+
41
+ @dataclass
42
+ class CheckpointMetadata:
43
+ """Metadata for a single checkpoint."""
44
+
45
+ step: int
46
+ timestamp: float
47
+ metrics: dict[str, float] = field(default_factory=dict)
48
+ epoch: int | None = None
49
+ phase: str = ""
50
+ is_best: bool = False
51
+ size_mb: float = 0.0
52
+
53
+ def to_dict(self) -> dict[str, Any]:
54
+ return asdict(self)
55
+
56
+ @classmethod
57
+ def from_dict(cls, d: dict[str, Any]) -> CheckpointMetadata:
58
+ return cls(**{k: v for k, v in d.items() if k in cls.__dataclass_fields__})
59
+
60
+
61
+ class CheckpointManager:
62
+ """Manages training checkpoints with pruning and best-model tracking.
63
+
64
+ Args:
65
+ output_dir: Base directory for checkpoints.
66
+ keep_best: Number of best checkpoints to retain.
67
+ keep_latest: Number of most recent checkpoints to retain.
68
+ metric: Metric name used to determine "best" checkpoint.
69
+ lower_is_better: If True, lower metric values are better (e.g. loss, FID).
70
+ prefix: Checkpoint directory prefix (default: "checkpoint").
71
+ """
72
+
73
+ INDEX_FILE = "checkpoint_index.json"
74
+
75
+ def __init__(
76
+ self,
77
+ output_dir: str | Path,
78
+ keep_best: int = 3,
79
+ keep_latest: int = 5,
80
+ metric: str = "loss",
81
+ lower_is_better: bool = True,
82
+ prefix: str = "checkpoint",
83
+ ) -> None:
84
+ self.output_dir = Path(output_dir)
85
+ self.output_dir.mkdir(parents=True, exist_ok=True)
86
+ self.keep_best = keep_best
87
+ self.keep_latest = keep_latest
88
+ self.metric = metric
89
+ self.lower_is_better = lower_is_better
90
+ self.prefix = prefix
91
+
92
+ self._index: dict[str, Any] = {"checkpoints": {}}
93
+ self._load_index()
94
+
95
+ # ------------------------------------------------------------------
96
+ # Index persistence
97
+ # ------------------------------------------------------------------
98
+
99
+ def _index_path(self) -> Path:
100
+ return self.output_dir / self.INDEX_FILE
101
+
102
+ def _load_index(self) -> None:
103
+ path = self._index_path()
104
+ if path.exists():
105
+ with open(path) as f:
106
+ self._index = json.load(f)
107
+ if "checkpoints" not in self._index:
108
+ self._index["checkpoints"] = {}
109
+
110
+ def _save_index(self) -> None:
111
+ with open(self._index_path(), "w") as f:
112
+ json.dump(self._index, f, indent=2)
113
+
114
+ # ------------------------------------------------------------------
115
+ # Save checkpoint
116
+ # ------------------------------------------------------------------
117
+
118
+ def save(
119
+ self,
120
+ step: int,
121
+ controlnet: torch.nn.Module,
122
+ ema_controlnet: torch.nn.Module,
123
+ optimizer: torch.optim.Optimizer,
124
+ scheduler: Any = None,
125
+ metrics: dict[str, float] | None = None,
126
+ epoch: int | None = None,
127
+ phase: str = "",
128
+ extra_state: dict[str, Any] | None = None,
129
+ ) -> Path:
130
+ """Save a checkpoint with metadata.
131
+
132
+ Args:
133
+ step: Current training step.
134
+ controlnet: ControlNet model (or any nn.Module).
135
+ ema_controlnet: EMA copy of the model.
136
+ optimizer: Optimizer state.
137
+ scheduler: Optional LR scheduler.
138
+ metrics: Dict of metric values at this step.
139
+ epoch: Optional epoch number.
140
+ phase: Training phase label (e.g. "A", "B").
141
+ extra_state: Any additional state to save.
142
+
143
+ Returns:
144
+ Path to the saved checkpoint directory.
145
+ """
146
+ ckpt_name = f"{self.prefix}-{step}"
147
+ ckpt_dir = self.output_dir / ckpt_name
148
+ ckpt_dir.mkdir(exist_ok=True)
149
+
150
+ # Save EMA weights (used for inference)
151
+ if hasattr(ema_controlnet, "save_pretrained"):
152
+ ema_controlnet.save_pretrained(ckpt_dir / "controlnet_ema")
153
+
154
+ # Save training state for resume
155
+ state = {
156
+ "controlnet": _get_state_dict(controlnet),
157
+ "ema_controlnet": _get_state_dict(ema_controlnet),
158
+ "optimizer": optimizer.state_dict(),
159
+ "global_step": step,
160
+ }
161
+ if scheduler is not None:
162
+ state["scheduler"] = scheduler.state_dict()
163
+ if extra_state:
164
+ state.update(extra_state)
165
+
166
+ torch.save(state, ckpt_dir / "training_state.pt")
167
+
168
+ # Compute checkpoint size
169
+ size_mb = sum(
170
+ f.stat().st_size for f in ckpt_dir.rglob("*") if f.is_file()
171
+ ) / (1024 * 1024)
172
+
173
+ # Create metadata
174
+ meta = CheckpointMetadata(
175
+ step=step,
176
+ timestamp=time.time(),
177
+ metrics=metrics or {},
178
+ epoch=epoch,
179
+ phase=phase,
180
+ size_mb=round(size_mb, 1),
181
+ )
182
+
183
+ # Save metadata alongside checkpoint
184
+ with open(ckpt_dir / "metadata.json", "w") as f:
185
+ json.dump(meta.to_dict(), f, indent=2)
186
+
187
+ # Update index
188
+ self._index["checkpoints"][ckpt_name] = meta.to_dict()
189
+ self._update_best()
190
+ self._save_index()
191
+
192
+ # Update symlinks
193
+ self._update_symlinks()
194
+
195
+ # Prune old checkpoints
196
+ self._prune()
197
+
198
+ return ckpt_dir
199
+
200
+ # ------------------------------------------------------------------
201
+ # Best / latest tracking
202
+ # ------------------------------------------------------------------
203
+
204
+ def _update_best(self) -> None:
205
+ """Recompute which checkpoints are 'best'."""
206
+ entries = []
207
+ for name, meta in self._index["checkpoints"].items():
208
+ val = meta.get("metrics", {}).get(self.metric)
209
+ if val is not None:
210
+ entries.append((name, val, meta))
211
+
212
+ if not entries:
213
+ return
214
+
215
+ # Sort by metric
216
+ entries.sort(key=lambda x: x[1], reverse=not self.lower_is_better)
217
+
218
+ # Mark best
219
+ best_names = {e[0] for e in entries[:self.keep_best]}
220
+ for name, meta in self._index["checkpoints"].items():
221
+ meta["is_best"] = name in best_names
222
+
223
+ def _update_symlinks(self) -> None:
224
+ """Update 'latest' and 'best' symlinks."""
225
+ checkpoints = self._sorted_by_step()
226
+ if not checkpoints:
227
+ return
228
+
229
+ # Latest symlink
230
+ latest_name = checkpoints[-1]
231
+ latest_link = self.output_dir / "latest"
232
+ _force_symlink(self.output_dir / latest_name, latest_link)
233
+
234
+ # Best symlink
235
+ best_name = self.get_best_checkpoint_name()
236
+ if best_name:
237
+ best_link = self.output_dir / "best"
238
+ _force_symlink(self.output_dir / best_name, best_link)
239
+
240
+ def get_best_checkpoint_name(self) -> str | None:
241
+ """Return the name of the best checkpoint by tracked metric."""
242
+ best = None
243
+ best_val = None
244
+ for name, meta in self._index["checkpoints"].items():
245
+ val = meta.get("metrics", {}).get(self.metric)
246
+ if val is None:
247
+ continue
248
+ if best_val is None:
249
+ best, best_val = name, val
250
+ elif self.lower_is_better and val < best_val:
251
+ best, best_val = name, val
252
+ elif not self.lower_is_better and val > best_val:
253
+ best, best_val = name, val
254
+ return best
255
+
256
+ def get_best_metric_value(self) -> float | None:
257
+ """Return the best value of the tracked metric."""
258
+ name = self.get_best_checkpoint_name()
259
+ if name is None:
260
+ return None
261
+ return self._index["checkpoints"][name]["metrics"].get(self.metric)
262
+
263
+ # ------------------------------------------------------------------
264
+ # Pruning
265
+ # ------------------------------------------------------------------
266
+
267
+ def _sorted_by_step(self) -> list[str]:
268
+ """Return checkpoint names sorted by step (ascending)."""
269
+ items = list(self._index["checkpoints"].items())
270
+ items.sort(key=lambda x: x[1].get("step", 0))
271
+ return [name for name, _ in items]
272
+
273
+ def _prune(self) -> None:
274
+ """Remove old checkpoints, keeping best N and latest M."""
275
+ all_names = self._sorted_by_step()
276
+ if len(all_names) <= self.keep_latest:
277
+ return
278
+
279
+ # Determine which to keep
280
+ keep = set()
281
+
282
+ # Keep latest
283
+ for name in all_names[-self.keep_latest:]:
284
+ keep.add(name)
285
+
286
+ # Keep best
287
+ for name, meta in self._index["checkpoints"].items():
288
+ if meta.get("is_best", False):
289
+ keep.add(name)
290
+
291
+ # Delete the rest
292
+ for name in all_names:
293
+ if name not in keep:
294
+ ckpt_dir = self.output_dir / name
295
+ if ckpt_dir.exists():
296
+ shutil.rmtree(ckpt_dir)
297
+ del self._index["checkpoints"][name]
298
+
299
+ self._save_index()
300
+
301
+ # ------------------------------------------------------------------
302
+ # Queries
303
+ # ------------------------------------------------------------------
304
+
305
+ def list_checkpoints(self) -> list[dict[str, Any]]:
306
+ """Return metadata for all tracked checkpoints, sorted by step."""
307
+ result = []
308
+ for name in self._sorted_by_step():
309
+ meta = self._index["checkpoints"][name]
310
+ result.append({"name": name, **meta})
311
+ return result
312
+
313
+ def get_checkpoint_path(self, name: str) -> Path:
314
+ """Return the filesystem path for a checkpoint by name."""
315
+ return self.output_dir / name
316
+
317
+ def get_latest_step(self) -> int:
318
+ """Return the step of the most recent checkpoint, or 0."""
319
+ names = self._sorted_by_step()
320
+ if not names:
321
+ return 0
322
+ return self._index["checkpoints"][names[-1]].get("step", 0)
323
+
324
+ def total_size_mb(self) -> float:
325
+ """Return total disk size of all tracked checkpoints."""
326
+ return sum(
327
+ meta.get("size_mb", 0.0)
328
+ for meta in self._index["checkpoints"].values()
329
+ )
330
+
331
+ def summary(self) -> str:
332
+ """Return a human-readable summary of checkpoint state."""
333
+ ckpts = self.list_checkpoints()
334
+ if not ckpts:
335
+ return "No checkpoints saved."
336
+
337
+ lines = [
338
+ f"Checkpoints: {len(ckpts)} saved ({self.total_size_mb():.0f} MB total)",
339
+ f"Latest: step {self.get_latest_step()}",
340
+ ]
341
+
342
+ best_name = self.get_best_checkpoint_name()
343
+ best_val = self.get_best_metric_value()
344
+ if best_name and best_val is not None:
345
+ lines.append(f"Best ({self.metric}): {best_val:.6f} @ {best_name}")
346
+
347
+ return "\n".join(lines)
348
+
349
+
350
+ # ------------------------------------------------------------------
351
+ # Helpers
352
+ # ------------------------------------------------------------------
353
+
354
+ def _get_state_dict(module: torch.nn.Module) -> dict:
355
+ """Extract state dict, handling DDP wrapper."""
356
+ if hasattr(module, "module"):
357
+ return module.module.state_dict()
358
+ return module.state_dict()
359
+
360
+
361
+ def _force_symlink(target: Path, link: Path) -> None:
362
+ """Create or replace a symlink."""
363
+ if link.is_symlink() or link.exists():
364
+ link.unlink()
365
+ link.symlink_to(target.name)