cdotsanghvi's picture
initial transaction co-pilot deployment
b3112c7
Raw
History Blame Contribute Delete
3.25 kB
"""Schema definitions for the 15-feature transaction format.
Loads data/schema.yaml into typed dataclasses. Used by both the tokenizer
and the generator to ensure consistent feature handling.
"""
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

import yaml
# Reserved token ids shared by every feature vocabulary. They occupy the
# lowest ids; a feature's real values are numbered from VALUES_START upward,
# which is why FeatureSchema requires vocab_size == num_values + VALUES_START.
MASK_TOKEN: int = 0                  # reserved id: mask token
OOV_TOKEN: int = MASK_TOKEN + 1      # reserved id: out-of-vocabulary token
NULL_TOKEN: int = OOV_TOKEN + 1      # reserved id: null token
VALUES_START: int = NULL_TOKEN + 1   # first id available for real values
class SchemaError(ValueError, AssertionError):
    """Raised when a schema definition violates one of its invariants.

    Derives from ``ValueError`` (the conventional type for bad input) and from
    ``AssertionError`` (what these checks raised when they were written as
    ``assert`` statements), so existing ``except AssertionError`` handlers
    keep working. Unlike ``assert``, the checks are not stripped by
    ``python -O``.
    """


@dataclass(frozen=True)
class FeatureSchema:
    """Schema for a single transaction feature.

    The lowest ``VALUES_START`` token ids are reserved (the validation message
    names them MASK, OOV and NULL); the feature's ``num_values`` real values
    take the ids after them, so ``vocab_size`` must equal
    ``num_values + VALUES_START``.

    Raises:
        SchemaError: if ``vocab_size`` is below ``VALUES_START`` or does not
            equal ``num_values + VALUES_START``, or if a ``"bucketed"`` feature
            has no ``bucket_range`` or a ``bucket_method`` other than
            ``"quantile"`` / ``"uniform"``.
    """

    name: str
    family: str
    type: str
    num_values: int
    vocab_size: int
    description: str = ""
    bucket_range: Optional[tuple[float, float]] = None
    bucket_method: Optional[str] = None
    # Excluded from the generated __hash__: a dict is unhashable, so with it
    # included hash() raised TypeError whenever ``values`` was set, despite
    # frozen=True. Equality comparisons still take it into account.
    values: Optional[dict[int, str]] = field(default=None, hash=False)

    def __post_init__(self) -> None:
        # Explicit raises rather than ``assert`` so validation survives -O.
        if self.vocab_size < VALUES_START:
            raise SchemaError(
                f"Feature '{self.name}': vocab_size={self.vocab_size}, "
                f"minimum is {VALUES_START} (MASK + OOV + NULL)"
            )
        if self.vocab_size != self.num_values + VALUES_START:
            raise SchemaError(
                f"Feature '{self.name}': vocab_size ({self.vocab_size}) != "
                f"num_values ({self.num_values}) + {VALUES_START}"
            )
        if self.type == "bucketed":
            if self.bucket_range is None:
                raise SchemaError(
                    f"Feature '{self.name}': bucketed type requires bucket_range"
                )
            if self.bucket_method not in ("quantile", "uniform"):
                raise SchemaError(
                    f"Feature '{self.name}': bucket_method must be quantile or uniform"
                )


@dataclass(frozen=True)
class SchemaConfig:
    """Complete schema for all features in a transaction.

    Raises:
        SchemaError: if ``len(features) != num_features`` or if two features
            share the same name (name lookups would otherwise silently return
            only the first match).
    """

    num_features: int
    num_transactions: int
    features: tuple[FeatureSchema, ...]

    def __post_init__(self) -> None:
        if len(self.features) != self.num_features:
            raise SchemaError(
                f"Expected {self.num_features} features, got {len(self.features)}"
            )
        seen: set[str] = set()
        duplicates: set[str] = set()
        for feature_name in self.feature_names():
            if feature_name in seen:
                duplicates.add(feature_name)
            seen.add(feature_name)
        if duplicates:
            raise SchemaError(f"Duplicate feature names: {sorted(duplicates)}")

    def feature_names(self) -> list[str]:
        """Return the feature names in schema order."""
        return [f.name for f in self.features]

    def feature_index(self, name: str) -> int:
        """Return the position of feature *name* within ``features``.

        Raises:
            KeyError: if no feature has that name.
        """
        for i, f in enumerate(self.features):
            if f.name == name:
                return i
        raise KeyError(f"Unknown feature: {name}")

    def get_feature(self, name: str) -> FeatureSchema:
        """Return the :class:`FeatureSchema` named *name*.

        Raises:
            KeyError: if no feature has that name.
        """
        return self.features[self.feature_index(name)]
def load_schema(path: str | Path) -> SchemaConfig:
    """Load a YAML schema file (normally ``data/schema.yaml``) into dataclasses.

    The document is parsed with ``yaml.safe_load`` and is expected to be a
    mapping with ``num_features``, ``num_transactions`` and ``features`` keys.
    Each ``features`` entry must provide ``name``, ``family``, ``type``,
    ``num_values`` and ``vocab_size``; ``description``, ``bucket_range``,
    ``bucket_method`` and ``values`` are optional.

    Args:
        path: Path to the YAML schema file.

    Returns:
        A :class:`SchemaConfig` with one :class:`FeatureSchema` per entry, in
        file order. Consistency checks live in those classes' ``__post_init__``
        methods, so their validation errors propagate from here.

    Raises:
        KeyError: if a required key is missing.
        ValueError: if a feature's ``bucket_range`` does not have exactly two
            values.
    """
    # Explicit encoding: open()'s default is locale-dependent, so a schema
    # with non-ASCII descriptions or value labels could fail to decode, or
    # decode wrongly, on some machines.
    with open(path, encoding="utf-8") as fh:
        raw = yaml.safe_load(fh)

    features = tuple(_parse_feature(entry) for entry in raw["features"])
    return SchemaConfig(
        num_features=raw["num_features"],
        num_transactions=raw["num_transactions"],
        features=features,
    )


def _parse_feature(entry: dict) -> FeatureSchema:
    """Build a FeatureSchema from one raw ``features`` entry of the YAML file."""
    bucket_range = None
    raw_range = entry.get("bucket_range")
    # ``bucket_range: null`` is treated like an absent key instead of failing
    # with an opaque TypeError on ``None[0]``.
    if raw_range is not None:
        # Require exactly [low, high]; extra elements used to be dropped
        # silently.
        if len(raw_range) != 2:
            raise ValueError(
                f"Feature '{entry.get('name')}': bucket_range must have exactly "
                f"2 values, got {raw_range!r}"
            )
        low, high = raw_range
        bucket_range = (float(low), float(high))
    return FeatureSchema(
        name=entry["name"],
        family=entry["family"],
        type=entry["type"],
        num_values=entry["num_values"],
        vocab_size=entry["vocab_size"],
        description=entry.get("description", ""),
        bucket_range=bucket_range,
        bucket_method=entry.get("bucket_method"),
        values=entry.get("values"),
    )