| """Schema definitions for the 15-feature transaction format. |
| |
| Loads data/schema.yaml into typed dataclasses. Used by both the tokenizer |
| and the generator to ensure consistent feature handling. |
| """ |
|
|
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

import yaml
|
|
|
|
# Reserved token ids at the start of every feature vocabulary (presumably:
# masked position, out-of-vocabulary value, missing value). FeatureSchema
# below requires vocab_size == num_values + VALUES_START, i.e. the first
# VALUES_START ids of each vocabulary are reserved for these three tokens.
MASK_TOKEN: int = 0
OOV_TOKEN: int = 1
NULL_TOKEN: int = 2
VALUES_START: int = 3


class SchemaError(ValueError, AssertionError):
    """Raised when a feature or schema definition is internally inconsistent.

    Subclasses ValueError, the conventional type for bad input. It also
    subclasses AssertionError because these checks used to be ``assert``
    statements, so existing ``except AssertionError`` handlers keep working.
    """


@dataclass(frozen=True)
class FeatureSchema:
    """Schema for a single transaction feature.

    Instances are validated on construction; see ``__post_init__``.
    """

    name: str
    family: str
    type: str
    num_values: int
    vocab_size: int
    description: str = ""
    bucket_range: Optional[tuple[float, float]] = None
    bucket_method: Optional[str] = None
    # Excluded from the generated __hash__: a dict is unhashable, so hashing
    # it made hash() raise TypeError for every feature with a values mapping.
    # It still takes part in __eq__, so equal objects still hash equally.
    values: Optional[dict[int, str]] = field(default=None, hash=False)

    def __post_init__(self) -> None:
        """Validate vocabulary sizing and bucketing parameters.

        These are explicit raises rather than ``assert`` statements, so the
        checks still run under ``python -O``.

        Raises:
            SchemaError: if vocab_size is smaller than VALUES_START, if
                vocab_size != num_values + VALUES_START, or if a "bucketed"
                feature has no bucket_range or a bucket_method other than
                "quantile" / "uniform".
        """
        if self.vocab_size < VALUES_START:
            raise SchemaError(
                f"Feature '{self.name}': vocab_size={self.vocab_size}, "
                f"minimum is {VALUES_START} (MASK + OOV + NULL)"
            )
        if self.vocab_size != self.num_values + VALUES_START:
            raise SchemaError(
                f"Feature '{self.name}': vocab_size ({self.vocab_size}) != "
                f"num_values ({self.num_values}) + {VALUES_START}"
            )
        if self.type == "bucketed":
            if self.bucket_range is None:
                raise SchemaError(
                    f"Feature '{self.name}': bucketed type requires bucket_range"
                )
            if self.bucket_method not in ("quantile", "uniform"):
                raise SchemaError(
                    f"Feature '{self.name}': bucket_method must be quantile or uniform"
                )


@dataclass(frozen=True)
class SchemaConfig:
    """Complete schema for all features in a transaction.

    ``features`` is ordered; feature_index() positions refer to this order.
    """

    num_features: int
    num_transactions: int
    features: tuple[FeatureSchema, ...]

    def __post_init__(self) -> None:
        """Validate the feature count and the uniqueness of feature names.

        Raises:
            SchemaError: if len(features) != num_features, or if two features
                share a name. Duplicate names would make feature_index() and
                get_feature() silently pick the first match.
        """
        if len(self.features) != self.num_features:
            raise SchemaError(
                f"Expected {self.num_features} features, got {len(self.features)}"
            )
        seen: set[str] = set()
        duplicates: list[str] = []
        for feature in self.features:
            if feature.name in seen and feature.name not in duplicates:
                duplicates.append(feature.name)
            seen.add(feature.name)
        if duplicates:
            raise SchemaError(f"Duplicate feature names: {duplicates}")

    def feature_names(self) -> list[str]:
        """Return the feature names in schema order, as a new list."""
        return [f.name for f in self.features]

    def feature_index(self, name: str) -> int:
        """Return the position of the feature called *name* in ``features``.

        Raises:
            KeyError: if no feature has that name.
        """
        for i, f in enumerate(self.features):
            if f.name == name:
                return i
        raise KeyError(f"Unknown feature: {name}")

    def get_feature(self, name: str) -> FeatureSchema:
        """Return the FeatureSchema called *name*.

        Raises:
            KeyError: if no feature has that name.
        """
        return self.features[self.feature_index(name)]
|
|
|
|
def _feature_from_mapping(entry: dict) -> FeatureSchema:
    """Build one FeatureSchema from an entry of the YAML ``features`` list.

    Required keys: name, family, type, num_values, vocab_size. Optional keys:
    description (default ""), bucket_range, bucket_method, values.

    Raises:
        KeyError: if a required key is missing.
        ValueError / TypeError: if bucket_range is not a two-element sequence
            of numbers.
        Any validation error raised by FeatureSchema itself.
    """
    raw_range = entry.get("bucket_range")
    bucket_range = None
    # An explicit ``bucket_range: null`` is treated the same as an absent key.
    if raw_range is not None:
        # Unpacking rejects ranges that do not have exactly two endpoints;
        # indexing [0] and [1] silently dropped any extra values.
        low, high = raw_range
        bucket_range = (float(low), float(high))
    return FeatureSchema(
        name=entry["name"],
        family=entry["family"],
        type=entry["type"],
        num_values=entry["num_values"],
        vocab_size=entry["vocab_size"],
        description=entry.get("description", ""),
        bucket_range=bucket_range,
        bucket_method=entry.get("bucket_method"),
        values=entry.get("values"),
    )


def load_schema(path: str | Path) -> SchemaConfig:
    """Load a schema YAML file (e.g. data/schema.yaml) into typed dataclasses.

    Args:
        path: Path to the YAML file. Its top level must be a mapping with
            ``num_features``, ``num_transactions`` and a ``features`` list.

    Returns:
        A SchemaConfig whose ``features`` tuple preserves file order.

    Raises:
        OSError: if the file cannot be opened.
        TypeError: if the document's top level is not a mapping. This
            includes an empty file, for which yaml.safe_load returns None.
        KeyError: if a required key is missing.
        Any validation error raised by FeatureSchema / SchemaConfig.
    """
    # Decode explicitly as UTF-8 so parsing does not depend on the platform's
    # default locale encoding.
    with open(path, encoding="utf-8") as fh:
        raw = yaml.safe_load(fh)

    if not isinstance(raw, dict):
        raise TypeError(
            f"Schema file {path}: expected a top-level mapping, "
            f"got {type(raw).__name__}"
        )

    features = tuple(_feature_from_mapping(entry) for entry in raw["features"])
    return SchemaConfig(
        num_features=raw["num_features"],
        num_transactions=raw["num_transactions"],
        features=features,
    )
|
|