"""
src/continual/registry.py — Continual-learning task registry

Persists a JSON list of registered generator classes (tasks) to disk so
training scripts and the API can stay in sync.

Schema (per entry):
    {
        "label": "sora",                  # generator label string
        "sample_path": "data/sora/...",   # representative sample path
        "registered_at": "2025-01-01T00:00:00+00:00",  # UTC, ISO 8601
        "weight_path": null               # filled after EWC checkpoint is saved
    }

Usage:
    registry = TaskRegistry()
    registry.register("sora", "data/sora/frame_001.jpg")
    tasks = registry.all_tasks()
"""

from __future__ import annotations

import json
import logging
import os
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

logger = logging.getLogger(__name__)

_DEFAULT_REGISTRY_PATH = Path("data/continual_registry.json")


class TaskRegistry:
    """
    Simple JSON-backed task registry for continual-learning experiments.

    Nothing is cached in memory: every public method re-reads the registry
    file, and every mutation rewrites it in full.

    Parameters
    ----------
    path:
        Path to the JSON registry file. Created on first write.
    """

    def __init__(self, path: Path | str = _DEFAULT_REGISTRY_PATH) -> None:
        self._path = Path(path)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def register(
        self,
        label: str,
        sample_path: str,
        weight_path: str | None = None,
    ) -> dict[str, Any]:
        """
        Register a new generator-class task.

        The entry is timestamped with the current UTC time and the registry
        file is rewritten immediately.

        Returns the new registry entry dict.
        Raises ValueError if *label* already registered.
        """
        tasks = self._load()
        if self._find(tasks, label) is not None:
            raise ValueError(
                f"Task '{label}' already registered. "
                f"Use update() to modify it."
            )

        entry: dict[str, Any] = {
            "label": label,
            "sample_path": sample_path,
            "registered_at": datetime.now(timezone.utc).isoformat(),
            "weight_path": weight_path,
        }
        tasks.append(entry)
        self._save(tasks)
        logger.info("Registered continual task: %s", label)
        return entry

    def update(self, label: str, **fields: Any) -> dict[str, Any]:
        """
        Update fields on an existing task entry.

        *fields* are merged into the first entry whose label matches and the
        registry file is rewritten.

        Returns the updated entry dict.
        Raises KeyError if *label* is not registered.
        """
        tasks = self._load()
        task = self._find(tasks, label)
        if task is None:
            raise KeyError(f"Task '{label}' not found in registry.")
        task.update(fields)
        self._save(tasks)
        return task

    def all_tasks(self) -> list[dict[str, Any]]:
        """Return all registered tasks (empty list if none are stored)."""
        return self._load()

    def get(self, label: str) -> dict[str, Any] | None:
        """Return a single task by label, or None."""
        return self._find(self._load(), label)

    # ------------------------------------------------------------------
    # Internal
    # ------------------------------------------------------------------

    @staticmethod
    def _find(
        tasks: list[dict[str, Any]],
        label: str,
    ) -> dict[str, Any] | None:
        """Return the first entry in *tasks* labelled *label*, or None."""
        # .get() rather than ["label"]: one entry missing the key must not
        # make every lookup in the registry raise KeyError.
        return next((t for t in tasks if t.get("label") == label), None)

    def _load(self) -> list[dict[str, Any]]:
        """
        Read the registry file and return its list of entries.

        Returns an empty list when the file does not exist, is not valid
        JSON, or holds valid JSON that is not a list.
        """
        # EAFP: read directly instead of exists()-then-read, which races
        # with another process creating or removing the file.
        try:
            raw = self._path.read_text(encoding="utf-8")
        except (FileNotFoundError, NotADirectoryError):
            return []

        try:
            data = json.loads(raw)
        except json.JSONDecodeError:
            logger.warning("Registry file corrupt; starting fresh.")
            return []

        # Valid JSON of the wrong shape (dict, string, number, ...) would
        # otherwise fail later with an opaque TypeError in the callers.
        if not isinstance(data, list):
            logger.warning("Registry file corrupt; starting fresh.")
            return []
        return data

    def _save(self, tasks: list[dict[str, Any]]) -> None:
        """
        Write *tasks* to the registry file, creating parent directories.

        The payload goes to a sibling temporary file that is then moved
        over the target with os.replace, so an interrupted write cannot
        leave a truncated registry behind.
        """
        self._path.parent.mkdir(parents=True, exist_ok=True)
        payload = json.dumps(tasks, indent=2, default=str)

        # Same directory as the target so the final rename stays on one
        # filesystem.
        tmp_path = self._path.with_name(self._path.name + ".tmp")
        try:
            tmp_path.write_text(payload, encoding="utf-8")
            os.replace(tmp_path, self._path)
        finally:
            # Only present here if the write or the replace failed.
            if tmp_path.exists():
                tmp_path.unlink()