Spaces:
Running
Running
| """ | |
| src/continual/registry.py — Continual-learning task registry | |
| Persists a JSON list of registered generator classes (tasks) to disk so | |
| training scripts and the API can stay in sync. | |
| Schema (per entry): | |
| { | |
| "label": "sora", # generator label string | |
| "sample_path": "data/sora/...", # representative sample path | |
| "registered_at": "2025-01-01T00:00:00+00:00", # timezone-aware UTC timestamp, ISO-8601 | |
| "weight_path": null # filled after EWC checkpoint is saved | |
| } | |
| Usage: | |
| registry = TaskRegistry() | |
| registry.register("sora", "data/sora/frame_001.jpg") | |
| tasks = registry.all_tasks() | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import logging | |
| import os | |
| from datetime import datetime, timezone | |
| from pathlib import Path | |
| from typing import Any | |
logger = logging.getLogger(__name__)

# Default on-disk location of the registry (relative to the working directory).
_DEFAULT_REGISTRY_PATH = Path("data/continual_registry.json")


class TaskRegistry:
    """
    Simple JSON-backed task registry for continual-learning experiments.

    The registry keeps no state in memory: every public call re-reads the
    JSON file, and every mutation rewrites it in full.

    Parameters
    ----------
    path:
        Path to the JSON registry file. Created on first write.
    """

    def __init__(self, path: Path | str = _DEFAULT_REGISTRY_PATH) -> None:
        self._path = Path(path)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def register(
        self,
        label: str,
        sample_path: str,
        weight_path: str | None = None,
    ) -> dict[str, Any]:
        """
        Register a new generator-class task.

        Parameters
        ----------
        label:
            Generator label string; must not already be registered.
        sample_path:
            Representative sample path stored with the entry.
        weight_path:
            Optional checkpoint path; ``None`` until one is recorded.

        Returns
        -------
        The new registry entry dict.

        Raises
        ------
        ValueError
            If *label* is already registered.
        """
        tasks = self._load()
        if any(task.get("label") == label for task in tasks):
            raise ValueError(
                f"Task '{label}' already registered. "
                f"Use update() to modify it."
            )
        entry: dict[str, Any] = {
            "label": label,
            "sample_path": sample_path,
            # Timezone-aware UTC timestamp, e.g. "2025-01-01T00:00:00+00:00".
            "registered_at": datetime.now(timezone.utc).isoformat(),
            "weight_path": weight_path,
        }
        tasks.append(entry)
        self._save(tasks)
        logger.info("Registered continual task: %s", label)
        return entry

    def update(self, label: str, **fields: Any) -> dict[str, Any]:
        """
        Update fields on an existing task entry.

        Parameters
        ----------
        label:
            Label of the task to modify.
        **fields:
            Key/value pairs merged into the stored entry.

        Returns
        -------
        The updated entry dict.

        Raises
        ------
        KeyError
            If no task with *label* exists.
        """
        tasks = self._load()
        for task in tasks:
            if task.get("label") == label:
                task.update(fields)
                self._save(tasks)
                return task
        raise KeyError(f"Task '{label}' not found in registry.")

    def all_tasks(self) -> list[dict[str, Any]]:
        """Return all registered tasks (an empty list if none exist)."""
        return self._load()

    def get(self, label: str) -> dict[str, Any] | None:
        """Return a single task by label, or None."""
        return next(
            (task for task in self._load() if task.get("label") == label),
            None,
        )

    # ------------------------------------------------------------------
    # Internal
    # ------------------------------------------------------------------

    def _load(self) -> list[dict[str, Any]]:
        """
        Read the registry file and return its task list.

        A missing file yields an empty list. A file that cannot be decoded,
        is not valid JSON, or does not hold a list of JSON objects is
        treated as corrupt: a warning is logged and an empty list returned.
        """
        # Read directly instead of checking exists() first, so a file removed
        # between the check and the read cannot raise.
        try:
            raw = self._path.read_text(encoding="utf-8")
        except FileNotFoundError:
            return []
        except UnicodeDecodeError:
            logger.warning("Registry file corrupt; starting fresh.")
            return []

        try:
            data = json.loads(raw)
        except json.JSONDecodeError:
            logger.warning("Registry file corrupt; starting fresh.")
            return []

        # Valid JSON of the wrong shape would break every caller that
        # iterates entries and reads their "label".
        if not isinstance(data, list) or not all(
            isinstance(item, dict) for item in data
        ):
            logger.warning(
                "Registry file does not contain a list of task objects; "
                "starting fresh."
            )
            return []
        return data

    def _save(self, tasks: list[dict[str, Any]]) -> None:
        """
        Write *tasks* to the registry file, creating parent directories.

        The data is written to a sibling temporary file and then moved over
        the registry path, so an interrupted write cannot leave a truncated
        registry behind.
        """
        self._path.parent.mkdir(parents=True, exist_ok=True)
        payload = json.dumps(tasks, indent=2, default=str)
        # Same directory as the target so the final swap is a plain rename.
        tmp_path = self._path.with_name(self._path.name + ".tmp")
        try:
            tmp_path.write_text(payload, encoding="utf-8")
            os.replace(tmp_path, self._path)
        except OSError:
            # Do not leave a stray temp file behind if the write or swap fails.
            if tmp_path.exists():
                tmp_path.unlink()
            raise