# deepdetection / src/continual/registry.py
# akagtag — Initial commit (4e75170)
"""
src/continual/registry.py — Continual-learning task registry
Persists a JSON list of registered generator classes (tasks) to disk so
training scripts and the API can stay in sync.
Schema (per entry):
{
"label": "sora", # generator label string
"sample_path": "data/sora/...", # representative sample path
"registered_at": "2025-01-01T00:00:00",
"weight_path": null # filled after EWC checkpoint is saved
}
Usage:
registry = TaskRegistry()
registry.register("sora", "data/sora/frame_001.jpg")
tasks = registry.all_tasks()
"""
from __future__ import annotations
import json
import logging
import os
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)

# Relative path, so it is resolved against the process's working directory.
_DEFAULT_REGISTRY_PATH = Path("data/continual_registry.json")


class TaskRegistry:
    """
    Simple JSON-backed task registry for continual-learning experiments.

    The registry file is re-read on every call, so several instances (or
    processes) pointing at the same path observe each other's saved changes.
    Read-modify-write cycles are not locked, however: concurrent writers can
    still lose each other's updates.

    Parameters
    ----------
    path:
        Path to the JSON registry file. Created, together with any missing
        parent directories, on first write.
    """

    def __init__(self, path: Path | str = _DEFAULT_REGISTRY_PATH) -> None:
        self._path = Path(path)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def register(
        self,
        label: str,
        sample_path: str,
        weight_path: str | None = None,
    ) -> dict[str, Any]:
        """
        Register a new generator-class task and persist it.

        Parameters
        ----------
        label:
            Generator label string; must not already be registered.
        sample_path:
            Representative sample path for this generator.
        weight_path:
            Optional checkpoint path; can be set later via ``update()``.

        Returns
        -------
        dict
            The new registry entry. ``registered_at`` is the current UTC time
            in ISO-8601 format.

        Raises
        ------
        ValueError
            If *label* is already registered.
        """
        tasks = self._load()
        if any(task["label"] == label for task in tasks):
            raise ValueError(
                f"Task '{label}' already registered. "
                "Use update() to modify it."
            )
        entry: dict[str, Any] = {
            "label": label,
            "sample_path": sample_path,
            "registered_at": datetime.now(timezone.utc).isoformat(),
            "weight_path": weight_path,
        }
        tasks.append(entry)
        self._save(tasks)
        logger.info("Registered continual task: %s", label)
        return entry

    def update(self, label: str, **fields: Any) -> dict[str, Any]:
        """
        Merge *fields* into the entry for *label*, persist, and return it.

        Field names are not validated; any keyword is stored as-is.

        Raises
        ------
        KeyError
            If *label* is not registered (nothing is written in that case).
        """
        tasks = self._load()
        task = next((t for t in tasks if t["label"] == label), None)
        if task is None:
            raise KeyError(f"Task '{label}' not found in registry.")
        task.update(fields)
        self._save(tasks)
        return task

    def all_tasks(self) -> list[dict[str, Any]]:
        """Return all registered tasks, freshly loaded from disk."""
        return self._load()

    def get(self, label: str) -> dict[str, Any] | None:
        """
        Return the task registered under *label*, or ``None`` if absent.

        The dict is freshly loaded; mutating it does not change the registry
        file (use ``update()`` for that).
        """
        return next((t for t in self._load() if t["label"] == label), None)

    # ------------------------------------------------------------------
    # Internal
    # ------------------------------------------------------------------

    def _load(self) -> list[dict[str, Any]]:
        """
        Read the registry list from disk.

        Returns ``[]`` when the file does not exist. A file that is not valid
        UTF-8 JSON, or whose top-level value is not a list, is treated as
        corrupt: a warning is logged and ``[]`` is returned. The next save
        then overwrites it.
        """
        try:
            data = json.loads(self._path.read_text(encoding="utf-8"))
        except FileNotFoundError:
            return []
        except (json.JSONDecodeError, UnicodeDecodeError):
            logger.warning("Registry file corrupt; starting fresh.")
            return []
        if not isinstance(data, list):
            # e.g. a dict would otherwise be iterated as its keys and crash
            # on task["label"] in the public methods.
            logger.warning("Registry file corrupt; starting fresh.")
            return []
        return data

    def _save(self, tasks: list[dict[str, Any]]) -> None:
        """
        Overwrite the registry file with *tasks*.

        The JSON is written to a sibling temporary file and then moved over
        the real file with ``os.replace``. An interrupted save therefore
        leaves the previous registry intact rather than a truncated file,
        which ``_load`` would discard as corrupt. Values json cannot encode
        natively (e.g. ``Path``) are stringified via ``default=str``.
        """
        self._path.parent.mkdir(parents=True, exist_ok=True)
        payload = json.dumps(tasks, indent=2, default=str)
        # Per-process name stops concurrent processes from clobbering each
        # other's temp file. Using the same directory keeps os.replace on a
        # single filesystem.
        tmp_path = self._path.with_name(f".{self._path.name}.{os.getpid()}.tmp")
        try:
            tmp_path.write_text(payload, encoding="utf-8")
            os.replace(tmp_path, self._path)
        finally:
            # No-op after a successful replace; removes leftovers on failure.
            tmp_path.unlink(missing_ok=True)