File size: 11,784 Bytes
387e567
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30cc2b8
 
 
387e567
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30cc2b8
 
 
 
387e567
 
 
30cc2b8
 
 
387e567
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30cc2b8
387e567
 
 
 
 
 
30cc2b8
 
387e567
 
 
 
 
 
30cc2b8
 
 
387e567
 
 
 
 
30cc2b8
 
 
 
387e567
 
 
 
30cc2b8
387e567
30cc2b8
387e567
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30cc2b8
 
 
387e567
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
"""Model registry for checkpoint discovery and management.

Provides a unified interface for finding, loading, and comparing model
checkpoints across local directories and remote sources.

Usage:
    from landmarkdiff.model_registry import ModelRegistry

    registry = ModelRegistry("checkpoints/")

    # Discover all checkpoints
    models = registry.list_models()

    # Get best checkpoint by metric
    best = registry.get_best("loss")

    # Load a specific checkpoint
    state = registry.load("checkpoint-5000")

    # Compare multiple checkpoints
    comparison = registry.compare(["checkpoint-1000", "checkpoint-5000"])
"""

from __future__ import annotations

import json
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

import torch


@dataclass
class ModelEntry:
    """Registry record describing a single checkpoint directory on disk."""

    # Registry key for this checkpoint (e.g. "checkpoint-5000").
    name: str
    # Directory the checkpoint lives in.
    path: Path
    # Training step; 0 when unknown.
    step: int = 0
    # Free-form training phase label; empty when unknown.
    phase: str = ""
    # Metric name -> value; a fresh dict per instance.
    metrics: dict[str, float] = field(default_factory=dict)
    # Approximate on-disk size in megabytes.
    size_mb: float = 0.0
    # Whether a "controlnet_ema" directory was found.
    has_ema: bool = False
    # Whether a "training_state.pt" file was found.
    has_training_state: bool = False

    @property
    def inference_path(self) -> Path | None:
        """Return the best available weights location for inference.

        The EMA directory ("controlnet_ema") is preferred; otherwise the
        "training_state.pt" file is used. Returns None when neither exists.
        """
        preferred_order = (
            self.path / "controlnet_ema",
            self.path / "training_state.pt",
        )
        for candidate in preferred_order:
            if candidate.exists():
                return candidate
        return None


class ModelRegistry:
    """Central registry for discovering and managing model checkpoints.

    A checkpoint is a ``checkpoint-*`` subdirectory (or a ``final`` /
    ``best`` / ``latest`` subdirectory) of one of the base directories that
    contains ``training_state.pt`` and/or a ``controlnet_ema`` directory.
    Optional ``metadata.json`` files supply step, phase, metrics and size.

    Args:
        checkpoint_dirs: One or more directories to scan for checkpoints.
        scan_on_init: Whether to scan directories immediately on creation.
    """

    # Extra subdirectory names registered as "<base_dir name>/<special>".
    _SPECIAL_DIRS = ("final", "best", "latest")

    def __init__(
        self,
        *checkpoint_dirs: str | Path,
        scan_on_init: bool = True,
    ) -> None:
        self.checkpoint_dirs = [Path(d) for d in checkpoint_dirs]
        # Registered checkpoints, keyed by ModelEntry.name.
        self._models: dict[str, ModelEntry] = {}

        if scan_on_init:
            self.scan()

    def scan(self) -> int:
        """Re-scan all checkpoint directories, replacing previous results.

        Base directories that do not exist are skipped.

        Returns:
            Number of models found.
        """
        self._models.clear()
        for base_dir in self.checkpoint_dirs:
            if base_dir.exists():
                self._scan_directory(base_dir)
        return len(self._models)

    def _scan_directory(self, base_dir: Path) -> None:
        """Register every checkpoint found directly under ``base_dir``."""
        # NOTE: checkpoint-* entries are keyed by their bare directory name,
        # so a same-named checkpoint in a later base dir replaces an earlier one.
        for ckpt_dir in sorted(base_dir.glob("checkpoint-*")):
            if not ckpt_dir.is_dir():
                continue
            entry = self._load_entry(ckpt_dir)
            if entry is not None:
                self._models[entry.name] = entry

        # "final"/"best"/"latest" may be real dirs or symlinks; is_dir()
        # follows symlinks and is False for missing paths.
        for special in self._SPECIAL_DIRS:
            special_dir = base_dir / special
            if not special_dir.is_dir():
                continue
            entry = self._load_entry(special_dir)
            if entry is not None:
                # Prefix with the base dir name so special entries from
                # different base dirs do not collide.
                entry.name = f"{base_dir.name}/{special}"
                self._models[entry.name] = entry

    @staticmethod
    def _step_from_name(name: str) -> int:
        """Parse the trailing ``-<digits>`` of a directory name, or return 0."""
        _, sep, tail = name.rpartition("-")
        # isdecimal() (unlike isdigit()) only accepts characters int() can parse.
        return int(tail) if sep and tail.isdecimal() else 0

    @staticmethod
    def _dir_size_mb(ckpt_dir: Path) -> float:
        """Total size of all files under ``ckpt_dir`` in MB, rounded to 0.1."""
        total = sum(f.stat().st_size for f in ckpt_dir.rglob("*") if f.is_file())
        return round(total / (1024 * 1024), 1)

    @staticmethod
    def _read_metadata(ckpt_dir: Path) -> dict[str, Any]:
        """Read ``metadata.json`` from a checkpoint dir; ``{}`` if absent or unreadable."""
        meta_path = ckpt_dir / "metadata.json"
        if not meta_path.exists():
            return {}
        try:
            with open(meta_path, encoding="utf-8") as f:
                meta = json.load(f)
        except (OSError, ValueError) as err:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            # A single bad file should not abort the whole scan.
            warnings.warn(
                f"Ignoring unreadable {meta_path}: {err}",
                stacklevel=2,
            )
            return {}
        return meta if isinstance(meta, dict) else {}

    def _load_entry(self, ckpt_dir: Path) -> ModelEntry | None:
        """Build a ModelEntry for one checkpoint directory.

        Values from ``metadata.json`` take precedence; a missing step falls
        back to the directory-name suffix and a missing size is measured.

        Returns:
            The entry, or None if the directory holds neither
            ``training_state.pt`` nor ``controlnet_ema``.
        """
        has_training = (ckpt_dir / "training_state.pt").exists()
        has_ema = (ckpt_dir / "controlnet_ema").exists()
        if not (has_training or has_ema):
            return None

        meta = self._read_metadata(ckpt_dir)
        size_mb = meta.get("size_mb")

        return ModelEntry(
            name=ckpt_dir.name,
            path=ckpt_dir,
            step=meta.get("step", self._step_from_name(ckpt_dir.name)),
            phase=meta.get("phase", ""),
            metrics=dict(meta.get("metrics") or {}),
            size_mb=size_mb if size_mb is not None else self._dir_size_mb(ckpt_dir),
            has_ema=has_ema,
            has_training_state=has_training,
        )

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------

    def list_models(self, sort_by: str = "step") -> list[ModelEntry]:
        """List all registered models.

        Args:
            sort_by: Sort key — "step", "name", or a metric name. When
                sorting by a metric, models lacking it sort last.

        Returns:
            Sorted list of ModelEntry objects.
        """
        models = list(self._models.values())
        if sort_by == "step":
            models.sort(key=lambda m: m.step)
        elif sort_by == "name":
            models.sort(key=lambda m: m.name)
        else:
            models.sort(key=lambda m: m.metrics.get(sort_by, float("inf")))
        return models

    def get(self, name: str) -> ModelEntry | None:
        """Get a model entry by name, or None if not registered."""
        return self._models.get(name)

    def get_best(
        self,
        metric: str = "loss",
        lower_is_better: bool = True,
    ) -> ModelEntry | None:
        """Get the best model by a specific metric.

        Args:
            metric: Metric name to rank by.
            lower_is_better: If True, lower values are better.

        Returns:
            Best ModelEntry, or None if no models have the metric.
        """
        candidates = [m for m in self._models.values() if metric in m.metrics]
        if not candidates:
            return None
        pick = min if lower_is_better else max
        return pick(candidates, key=lambda m: m.metrics[metric])

    def get_by_step(self, step: int) -> ModelEntry | None:
        """Get the first registered model with the given training step.

        Note: "final"/"best"/"latest" entries may share a step with a
        ``checkpoint-*`` entry; whichever was registered first is returned.
        """
        return next((m for m in self._models.values() if m.step == step), None)

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------

    def load(
        self,
        name: str,
        map_location: str = "cpu",
    ) -> dict[str, Any]:
        """Load training state from a checkpoint.

        Uses ``torch.load(..., weights_only=True)``, which restricts
        unpickling to tensors and plain containers.

        Args:
            name: Checkpoint name (e.g. "checkpoint-5000").
            map_location: Device to load tensors to.

        Returns:
            The state dict stored in ``training_state.pt``.

        Raises:
            KeyError: If checkpoint not found.
            FileNotFoundError: If training_state.pt missing.
        """
        entry = self._models.get(name)
        if entry is None:
            raise KeyError(f"Checkpoint '{name}' not found in registry")

        state_path = entry.path / "training_state.pt"
        if not state_path.exists():
            raise FileNotFoundError(f"No training_state.pt in {entry.path}")

        return torch.load(state_path, map_location=map_location, weights_only=True)

    def load_controlnet(
        self,
        name: str,
        use_ema: bool = True,
        torch_dtype: torch.dtype | None = None,
        base_model: str = "CrucibleAI/ControlNetMediaPipeFace",
        base_subfolder: str | None = "diffusion_sd15",
    ) -> Any:
        """Load a ControlNet model from checkpoint.

        If ``use_ema`` and the checkpoint has a ``controlnet_ema`` directory,
        that directory is loaded directly. Otherwise ``base_model`` is
        instantiated and the weights from ``training_state.pt`` are loaded
        into it. With ``use_ema=True``, "ema_controlnet" weights are used
        when present, else "controlnet".

        Args:
            name: Checkpoint name.
            use_ema: If True, prefer EMA weights (preferred for inference).
            torch_dtype: Weight dtype (e.g. torch.float16). Defaults to
                float16 on CUDA, float32 on CPU.
            base_model: Pretrained model id/path providing the architecture
                for the training-state fallback.
            base_subfolder: Subfolder within ``base_model``.

        Returns:
            ControlNetModel instance.

        Raises:
            KeyError: If the checkpoint is not registered, or its training
                state lacks the requested weights.
            FileNotFoundError: If the fallback needs training_state.pt and
                it is missing.
        """
        from diffusers import ControlNetModel

        if torch_dtype is None:
            torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

        entry = self._models.get(name)
        if entry is None:
            raise KeyError(f"Checkpoint '{name}' not found in registry")

        if use_ema and entry.has_ema:
            return ControlNetModel.from_pretrained(
                str(entry.path / "controlnet_ema"),
                torch_dtype=torch_dtype,
            )

        # Fallback: load from training state. Validate the key before
        # fetching the (possibly remote) base model.
        state = self.load(name)
        key = "ema_controlnet" if use_ema and "ema_controlnet" in state else "controlnet"
        if key not in state:
            raise KeyError(
                f"Checkpoint '{name}' has no '{key}' weights in training_state.pt"
            )
        model = ControlNetModel.from_pretrained(
            base_model,
            subfolder=base_subfolder,
            torch_dtype=torch_dtype,
        )
        model.load_state_dict(state[key])
        return model

    # ------------------------------------------------------------------
    # Comparison
    # ------------------------------------------------------------------

    def compare(
        self,
        names: list[str],
        metrics: list[str] | None = None,
    ) -> dict[str, Any]:
        """Compare multiple checkpoints side-by-side.

        Args:
            names: List of checkpoint names to compare; unknown names are skipped.
            metrics: Specific metrics to include. None = union of all available.

        Returns:
            Dict with "metrics", "rows" (one dict per checkpoint, None for
            missing metric values) and "count"; or {"error": ...} if none of
            the names are registered.
        """
        entries = [self._models[n] for n in names if n in self._models]
        if not entries:
            return {"error": "No valid checkpoints found"}

        if metrics is None:
            metrics = sorted({k for e in entries for k in e.metrics})

        rows = []
        for e in entries:
            row = {
                "name": e.name,
                "step": e.step,
                "phase": e.phase,
                "size_mb": e.size_mb,
            }
            for m in metrics:
                row[m] = e.metrics.get(m)
            rows.append(row)

        return {
            # Copy so callers mutating the result don't alter their input list.
            "metrics": list(metrics),
            "rows": rows,
            "count": len(rows),
        }

    # ------------------------------------------------------------------
    # Summary
    # ------------------------------------------------------------------

    def summary(self) -> str:
        """Return a human-readable summary with step and metric ranges."""
        models = self.list_models()
        if not models:
            return "No models registered."

        total_size = sum(m.size_mb for m in models)
        lines = [
            f"Model Registry: {len(models)} checkpoints ({total_size:.0f} MB)",
            # models is sorted by step, so first/last give the range.
            f"  Steps: {models[0].step}–{models[-1].step}",
        ]

        all_metrics = sorted({k for m in models for k in m.metrics})
        for metric in all_metrics:
            values = [m.metrics[metric] for m in models if metric in m.metrics]
            if values:
                lines.append(f"  {metric}: {min(values):.4f}–{max(values):.4f}")

        return "\n".join(lines)

    def __len__(self) -> int:
        return len(self._models)

    def __contains__(self, name: str) -> bool:
        return name in self._models