File size: 3,731 Bytes
4e75170
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""
src/continual/registry.py — Continual-learning task registry

Persists a JSON list of registered generator classes (tasks) to disk so
training scripts and the API can stay in sync.

Schema (per entry):
  {
    "label":       "sora",           # generator label string
    "sample_path": "data/sora/...",  # representative sample path
    "registered_at": "2025-01-01T00:00:00",
    "weight_path": null              # filled after EWC checkpoint is saved
  }

Usage:
    registry = TaskRegistry()
    registry.register("sora", "data/sora/frame_001.jpg")
    tasks = registry.all_tasks()
"""
from __future__ import annotations

import json
import logging
import os
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

logger = logging.getLogger(__name__)

# Default registry location. It is a relative path, so it resolves against
# the current working directory of whichever process constructs the registry.
_DEFAULT_REGISTRY_PATH = Path("data/continual_registry.json")


class TaskRegistry:
    """
    Simple JSON-backed task registry for continual-learning experiments.

    Every public call re-reads the file, so separate ``TaskRegistry``
    instances pointed at the same path see each other's writes. Saves go
    through a temporary sibling file plus ``os.replace``, so a crash during
    a write leaves the previous registry intact rather than a truncated file.
    The read-modify-write cycle is not locked: concurrent writers can still
    lose each other's updates.

    Parameters
    ----------
    path:
        Path to the JSON registry file. Created (along with any missing
        parent directories) on first write.
    """

    def __init__(self, path: Path | str = _DEFAULT_REGISTRY_PATH) -> None:
        self._path = Path(path)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def register(
        self,
        label: str,
        sample_path: str,
        weight_path: str | None = None,
    ) -> dict[str, Any]:
        """
        Register a new generator-class task.

        Parameters
        ----------
        label:
            Unique generator label, e.g. ``"sora"``.
        sample_path:
            Path to a representative sample for this generator.
        weight_path:
            Optional checkpoint path; can be set later with ``update()``.

        Returns the new registry entry dict.
        Raises ValueError if *label* is already registered.
        """
        tasks = self._load()
        if any(task.get("label") == label for task in tasks):
            raise ValueError(
                f"Task '{label}' already registered. "
                "Use update() to modify it."
            )

        entry: dict[str, Any] = {
            "label": label,
            "sample_path": sample_path,
            # Timezone-aware UTC timestamp in ISO-8601 form.
            "registered_at": datetime.now(timezone.utc).isoformat(),
            "weight_path": weight_path,
        }
        tasks.append(entry)
        self._save(tasks)
        logger.info("Registered continual task: %s", label)
        return entry

    def update(self, label: str, **fields: Any) -> dict[str, Any]:
        """
        Merge *fields* into the entry for *label* and persist the result.

        Returns the updated entry dict.
        Raises KeyError if *label* is not registered.
        """
        tasks = self._load()
        for task in tasks:
            if task.get("label") == label:
                task.update(fields)
                self._save(tasks)
                return task
        raise KeyError(f"Task '{label}' not found in registry.")

    def all_tasks(self) -> list[dict[str, Any]]:
        """Return all registered tasks, freshly read from disk."""
        return self._load()

    def get(self, label: str) -> dict[str, Any] | None:
        """Return the task registered under *label*, or None if absent."""
        for task in self._load():
            if task.get("label") == label:
                return task
        return None

    # ------------------------------------------------------------------
    # Internal
    # ------------------------------------------------------------------

    def _load(self) -> list[dict[str, Any]]:
        """
        Read and return the task list; a missing file yields ``[]``.

        Undecodable content (invalid JSON or invalid UTF-8) and JSON that is
        not a list of objects are treated as corrupt. The file is moved
        aside by ``_quarantine`` and ``[]`` is returned, so the next save
        cannot silently overwrite whatever was there.
        """
        try:
            data = json.loads(self._path.read_text(encoding="utf-8"))
        except FileNotFoundError:
            return []
        except ValueError as exc:
            # Covers both json.JSONDecodeError and UnicodeDecodeError.
            self._quarantine(str(exc))
            return []
        if not isinstance(data, list) or not all(
            isinstance(task, dict) for task in data
        ):
            self._quarantine("expected a JSON list of objects")
            return []
        return data

    def _quarantine(self, reason: str) -> None:
        """Move an unreadable registry file to a unique ``.corrupt-*`` name."""
        stamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S%f")
        base = f"{self._path.name}.corrupt-{stamp}"
        backup = self._path.with_name(base)
        suffix = 1
        # Never overwrite an earlier backup, even if timestamps collide.
        while backup.exists():
            backup = self._path.with_name(f"{base}-{suffix}")
            suffix += 1
        try:
            os.replace(self._path, backup)
        except OSError:
            logger.warning(
                "Registry file %s corrupt (%s) and could not be moved aside; "
                "starting fresh.",
                self._path,
                reason,
                exc_info=True,
            )
            return
        logger.warning(
            "Registry file %s corrupt (%s); moved to %s, starting fresh.",
            self._path,
            reason,
            backup,
        )

    def _save(self, tasks: list[dict[str, Any]]) -> None:
        """
        Atomically persist *tasks* as indented JSON.

        The payload is written to a temporary sibling file, flushed to disk,
        then swapped in with ``os.replace``. Readers therefore see either the
        old registry or the new one, never a partial write.
        """
        self._path.parent.mkdir(parents=True, exist_ok=True)
        # Serialize before touching the disk. Non-JSON values (e.g.
        # pathlib.Path objects passed to update()) are stored via str().
        payload = json.dumps(tasks, indent=2, default=str)
        tmp = self._path.with_name(f".{self._path.name}.{os.getpid()}.tmp")
        try:
            with open(tmp, "w", encoding="utf-8") as fh:
                fh.write(payload)
                fh.flush()
                os.fsync(fh.fileno())
            os.replace(tmp, self._path)
        except BaseException:
            # Best-effort cleanup of the temp file; the original error wins.
            try:
                tmp.unlink()
            except OSError:
                pass
            raise