from __future__ import annotations from dataclasses import dataclass import torch @dataclass(frozen=True, slots=True) class EmbeddingBank: ids: tuple[str, ...] feats: torch.Tensor # (N, D) normalized @dataclass(frozen=True, slots=True) class LabelSetBank: label_set_hash: str name: str domains: EmbeddingBank labels_by_domain: dict[str, EmbeddingBank] # domain_id -> bank