Spaces:
Sleeping
Sleeping
File size: 3,731 Bytes
4e75170 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 | """
src/continual/registry.py — Continual-learning task registry
Persists a JSON list of registered generator classes (tasks) to disk so
training scripts and the API can stay in sync.
Schema (per entry):
{
"label": "sora", # generator label string
"sample_path": "data/sora/...", # representative sample path
"registered_at": "2025-01-01T00:00:00",
"weight_path": null # filled after EWC checkpoint is saved
}
Usage:
registry = TaskRegistry()
registry.register("sora", "data/sora/frame_001.jpg")
tasks = registry.all_tasks()
"""
from __future__ import annotations
import json
import logging
import os
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)
_DEFAULT_REGISTRY_PATH = Path("data/continual_registry.json")
class TaskRegistry:
"""
Simple JSON-backed task registry for continual-learning experiments.
Parameters
----------
path:
Path to the JSON registry file. Created on first write.
"""
def __init__(self, path: Path | str = _DEFAULT_REGISTRY_PATH) -> None:
self._path = Path(path)
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def register(
self,
label: str,
sample_path: str,
weight_path: str | None = None,
) -> dict[str, Any]:
"""
Register a new generator-class task.
Returns the new registry entry dict.
Raises ValueError if *label* already registered.
"""
tasks = self._load()
existing = [t for t in tasks if t["label"] == label]
if existing:
raise ValueError(
f"Task '{label}' already registered. "
f"Use update() to modify it."
)
entry: dict[str, Any] = {
"label": label,
"sample_path": sample_path,
"registered_at": datetime.now(timezone.utc).isoformat(),
"weight_path": weight_path,
}
tasks.append(entry)
self._save(tasks)
logger.info("Registered continual task: %s", label)
return entry
def update(self, label: str, **fields: Any) -> dict[str, Any]:
"""Update fields on an existing task entry."""
tasks = self._load()
for task in tasks:
if task["label"] == label:
task.update(fields)
self._save(tasks)
return task
raise KeyError(f"Task '{label}' not found in registry.")
def all_tasks(self) -> list[dict[str, Any]]:
"""Return all registered tasks."""
return self._load()
def get(self, label: str) -> dict[str, Any] | None:
"""Return a single task by label, or None."""
for task in self._load():
if task["label"] == label:
return task
return None
# ------------------------------------------------------------------
# Internal
# ------------------------------------------------------------------
def _load(self) -> list[dict[str, Any]]:
if not self._path.exists():
return []
try:
return json.loads(self._path.read_text(encoding="utf-8"))
except json.JSONDecodeError:
logger.warning("Registry file corrupt; starting fresh.")
return []
def _save(self, tasks: list[dict[str, Any]]) -> None:
self._path.parent.mkdir(parents=True, exist_ok=True)
self._path.write_text(
json.dumps(tasks, indent=2, default=str),
encoding="utf-8",
)
|