File size: 3,248 Bytes
b3112c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""Schema definitions for the 15-feature transaction format.

Loads data/schema.yaml into typed dataclasses. Used by both the tokenizer
and the generator to ensure consistent feature handling.
"""

from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import yaml


# Special token ids reserved at the start of every feature vocabulary.
# A feature's vocab_size is num_values plus these reserved slots.
MASK_TOKEN: int = 0
OOV_TOKEN: int = MASK_TOKEN + 1
NULL_TOKEN: int = OOV_TOKEN + 1
# First id available for real feature values; everything below is reserved.
VALUES_START: int = NULL_TOKEN + 1


@dataclass(frozen=True)
class FeatureSchema:
    """Schema for a single transaction feature.

    Instances are immutable. Validation runs in ``__post_init__`` on
    construction.

    Attributes (one per dataclass field):
        name: Feature identifier.
        family: Feature family/group label; not interpreted by this class.
        type: Feature kind. ``"bucketed"`` triggers extra bucketing checks;
            other values are not validated here.
        num_values: Number of real values, excluding the reserved token ids.
        vocab_size: Total vocabulary size; must equal
            ``num_values + VALUES_START``.
        description: Free-form text, empty by default.
        bucket_range: ``(low, high)`` bounds; required when ``type`` is
            ``"bucketed"``.
        bucket_method: ``"quantile"`` or ``"uniform"``; required when ``type``
            is ``"bucketed"``.
        values: Optional int-keyed mapping to string labels (presumably value
            id -> human-readable name; confirm against the schema file).
    """

    name: str
    family: str
    type: str
    num_values: int
    vocab_size: int
    description: str = ""
    bucket_range: Optional[tuple[float, float]] = None
    bucket_method: Optional[str] = None
    values: Optional[dict[int, str]] = None

    def __post_init__(self) -> None:
        """Validate vocabulary layout and bucketing configuration.

        Raises:
            AssertionError: If ``vocab_size`` is below the number of reserved
                ids or differs from ``num_values + VALUES_START``, or if a
                ``"bucketed"`` feature lacks ``bucket_range`` or has a
                ``bucket_method`` other than ``"quantile"``/``"uniform"``.
        """
        # Explicit ``raise`` rather than ``assert``: assert statements are
        # removed under ``python -O``, which would silently disable validation
        # of schema data read from disk. AssertionError is kept so existing
        # callers that catch it keep working.
        if self.vocab_size < VALUES_START:
            raise AssertionError(
                f"Feature '{self.name}': vocab_size={self.vocab_size}, "
                f"minimum is {VALUES_START} (MASK + OOV + NULL)"
            )
        if self.vocab_size != self.num_values + VALUES_START:
            raise AssertionError(
                f"Feature '{self.name}': vocab_size ({self.vocab_size}) != "
                f"num_values ({self.num_values}) + {VALUES_START}"
            )
        if self.type == "bucketed":
            if self.bucket_range is None:
                raise AssertionError(
                    f"Feature '{self.name}': bucketed type requires bucket_range"
                )
            if self.bucket_method not in ("quantile", "uniform"):
                raise AssertionError(
                    f"Feature '{self.name}': bucket_method must be quantile or uniform"
                )


@dataclass(frozen=True)
class SchemaConfig:
    """Complete schema for all features in a transaction.

    Attributes (one per dataclass field):
        num_features: Declared feature count; must equal ``len(features)``.
        num_transactions: Transaction count as declared by the schema source.
        features: Per-feature schemas. Their order defines the positions
            returned by ``feature_index``.
    """

    num_features: int
    num_transactions: int
    features: tuple[FeatureSchema, ...]

    def __post_init__(self) -> None:
        """Check that the feature count matches and that names are unique.

        Raises:
            AssertionError: If ``len(features) != num_features`` or if two
                features share a name.
        """
        # Explicit ``raise`` rather than ``assert`` so the check survives
        # ``python -O``; AssertionError is kept for existing callers.
        if len(self.features) != self.num_features:
            raise AssertionError(
                f"Expected {self.num_features} features, got {len(self.features)}"
            )
        # Name lookups below return the first match, so a duplicate name would
        # make the later feature unreachable; reject it up front.
        seen: set[str] = set()
        for f in self.features:
            if f.name in seen:
                raise AssertionError(f"Duplicate feature name: '{f.name}'")
            seen.add(f.name)

    def feature_names(self) -> list[str]:
        """Return the feature names in ``features`` order."""
        return [f.name for f in self.features]

    def feature_index(self, name: str) -> int:
        """Return the position of the feature called *name* in ``features``.

        Raises:
            KeyError: If no feature has that name.
        """
        for i, f in enumerate(self.features):
            if f.name == name:
                return i
        raise KeyError(f"Unknown feature: {name}")

    def get_feature(self, name: str) -> FeatureSchema:
        """Return the feature schema called *name*.

        Raises:
            KeyError: If no feature has that name.
        """
        return self.features[self.feature_index(name)]


def _feature_from_mapping(entry: dict) -> FeatureSchema:
    """Build a FeatureSchema from one entry of the YAML ``features`` list."""
    raw_range = entry.get("bucket_range")
    bucket_range = None
    # ``bucket_range: null`` is treated like an absent key, so a bucketed
    # feature reports the schema's own "requires bucket_range" error instead
    # of crashing on ``None[0]``.
    if raw_range is not None:
        # Unpacking (instead of indexing [0]/[1]) rejects ranges that do not
        # have exactly two bounds rather than silently truncating them.
        low, high = raw_range
        bucket_range = (float(low), float(high))
    return FeatureSchema(
        name=entry["name"],
        family=entry["family"],
        type=entry["type"],
        num_values=entry["num_values"],
        vocab_size=entry["vocab_size"],
        description=entry.get("description", ""),
        bucket_range=bucket_range,
        bucket_method=entry.get("bucket_method"),
        values=entry.get("values"),
    )


def load_schema(path: str | Path) -> SchemaConfig:
    """Load a schema YAML file (e.g. data/schema.yaml) into typed dataclasses.

    Args:
        path: Location of the YAML schema file.

    Returns:
        A ``SchemaConfig`` whose ``features`` keep the order of the file's
        ``features`` list.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
        ValueError: If the top level is not a mapping (e.g. the file is empty),
            or a ``bucket_range`` list has a length other than two.
        KeyError: If a required key is missing.
        AssertionError: If FeatureSchema / SchemaConfig validation fails.
    """
    # Decode explicitly as UTF-8 so parsing does not depend on the platform's
    # locale encoding (descriptions may contain non-ASCII text).
    with open(path, encoding="utf-8") as fh:
        raw = yaml.safe_load(fh)

    # safe_load returns None for an empty file; fail with a clear message
    # instead of "'NoneType' object is not subscriptable".
    if not isinstance(raw, dict):
        raise ValueError(
            f"Schema file {path!s} must contain a YAML mapping at the top "
            f"level, got {type(raw).__name__}"
        )

    features = tuple(_feature_from_mapping(f) for f in raw["features"])
    return SchemaConfig(
        num_features=raw["num_features"],
        num_transactions=raw["num_transactions"],
        features=features,
    )