"""Schema definitions for the 15-feature transaction format.

Loads data/schema.yaml into typed dataclasses. Used by both the tokenizer and
the generator to ensure consistent feature handling.
"""

from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import yaml

# Special token ids reserved at the start of every feature vocabulary. A
# feature's vocabulary is these three ids followed by its ``num_values`` value
# ids, hence ``vocab_size == num_values + VALUES_START``.
MASK_TOKEN: int = 0
OOV_TOKEN: int = 1
NULL_TOKEN: int = 2
VALUES_START: int = 3


class SchemaError(ValueError, AssertionError):
    """Raised when a schema definition is inconsistent or malformed.

    Subclasses ``ValueError`` (the conventional type for bad input data) and
    also ``AssertionError``, so callers that caught the ``AssertionError``
    produced by the earlier ``assert``-based validation keep working.
    """


@dataclass(frozen=True)
class FeatureSchema:
    """Schema for a single transaction feature.

    Attributes:
        name: Feature identifier.
        family: Family label read from the schema YAML (semantics are defined
            by the schema file, not by this module).
        type: Feature kind; ``"bucketed"`` features must also provide
            ``bucket_range`` and ``bucket_method``.
        num_values: Number of real (non-special) values.
        vocab_size: Total vocabulary size; must equal
            ``num_values + VALUES_START``.
        description: Free-form description.
        bucket_range: ``(low, high)`` float pair for bucketed features.
        bucket_method: ``"quantile"`` or ``"uniform"`` for bucketed features.
        values: Optional int-keyed mapping of string labels (presumably value
            id -> label; confirm against schema.yaml).

    Raises:
        SchemaError: From ``__post_init__`` when the fields are inconsistent.
    """

    name: str
    family: str
    type: str
    num_values: int
    vocab_size: int
    description: str = ""
    bucket_range: Optional[tuple[float, float]] = None
    bucket_method: Optional[str] = None
    values: Optional[dict[int, str]] = None

    def __post_init__(self) -> None:
        # Explicit raises instead of ``assert``: assertions are stripped under
        # ``python -O``, which would let an invalid schema load silently.
        if self.vocab_size < VALUES_START:
            raise SchemaError(
                f"Feature '{self.name}': vocab_size={self.vocab_size}, "
                f"minimum is {VALUES_START} (MASK + OOV + NULL)"
            )
        if self.vocab_size != self.num_values + VALUES_START:
            raise SchemaError(
                f"Feature '{self.name}': vocab_size ({self.vocab_size}) != "
                f"num_values ({self.num_values}) + {VALUES_START}"
            )
        if self.type == "bucketed":
            if self.bucket_range is None:
                raise SchemaError(
                    f"Feature '{self.name}': bucketed type requires bucket_range"
                )
            if self.bucket_method not in ("quantile", "uniform"):
                raise SchemaError(
                    f"Feature '{self.name}': bucket_method must be quantile or uniform"
                )


@dataclass(frozen=True)
class SchemaConfig:
    """Complete schema for all features in a transaction.

    Attributes:
        num_features: Declared feature count; must equal ``len(features)``.
        num_transactions: Value of the YAML ``num_transactions`` key (not
            validated here).
        features: Per-feature schemas, in the order they appear in the file.

    Raises:
        SchemaError: From ``__post_init__`` if the feature count does not match
            or a feature name is repeated.
    """

    num_features: int
    num_transactions: int
    features: tuple[FeatureSchema, ...]

    def __post_init__(self) -> None:
        if len(self.features) != self.num_features:
            raise SchemaError(
                f"Expected {self.num_features} features, got {len(self.features)}"
            )
        # feature_index() resolves names by first match, so a repeated name
        # would make every later feature with that name unreachable.
        seen: set[str] = set()
        for feature in self.features:
            if feature.name in seen:
                raise SchemaError(f"Duplicate feature name: '{feature.name}'")
            seen.add(feature.name)

    def feature_names(self) -> list[str]:
        """Return feature names in schema order."""
        return [f.name for f in self.features]

    def feature_index(self, name: str) -> int:
        """Return the position of feature ``name`` in ``features``.

        Raises:
            KeyError: If no feature has that name.
        """
        for i, f in enumerate(self.features):
            if f.name == name:
                return i
        raise KeyError(f"Unknown feature: {name}")

    def get_feature(self, name: str) -> FeatureSchema:
        """Return the FeatureSchema named ``name``.

        Raises:
            KeyError: If no feature has that name.
        """
        return self.features[self.feature_index(name)]


def _parse_bucket_range(entry: dict) -> Optional[tuple[float, float]]:
    """Convert an entry's ``bucket_range`` to a float pair; None if absent or null."""
    raw_range = entry.get("bucket_range")
    if raw_range is None:
        return None
    if len(raw_range) != 2:
        # Previously extra elements were silently dropped and a 1-element list
        # failed with a bare IndexError.
        raise SchemaError(
            f"Feature '{entry.get('name')}': bucket_range must have exactly "
            f"two elements, got {raw_range!r}"
        )
    low, high = raw_range
    return (float(low), float(high))


def load_schema(path: str | Path) -> SchemaConfig:
    """Load a schema YAML file (e.g. data/schema.yaml) into typed dataclasses.

    Args:
        path: Path to the schema YAML file.

    Returns:
        A SchemaConfig whose features have passed FeatureSchema validation.

    Raises:
        OSError: If the file cannot be opened.
        KeyError: If a required key is missing from the YAML.
        SchemaError: If the document is not a mapping, a ``bucket_range`` is
            not a two-element sequence, or FeatureSchema / SchemaConfig
            validation fails.
    """
    # Binary mode lets PyYAML detect the stream encoding (UTF-8 / UTF-16 via
    # BOM) instead of relying on the platform's locale default.
    with open(path, "rb") as fh:
        raw = yaml.safe_load(fh)
    if not isinstance(raw, dict):
        raise SchemaError(
            f"Schema file {path} must contain a YAML mapping, "
            f"got {type(raw).__name__}"
        )

    features: list[FeatureSchema] = []
    for f in raw["features"]:
        features.append(FeatureSchema(
            name=f["name"],
            family=f["family"],
            type=f["type"],
            num_values=f["num_values"],
            vocab_size=f["vocab_size"],
            description=f.get("description", ""),
            bucket_range=_parse_bucket_range(f),
            bucket_method=f.get("bucket_method"),
            values=f.get("values"),
        ))

    return SchemaConfig(
        num_features=raw["num_features"],
        num_transactions=raw["num_transactions"],
        features=tuple(features),
    )