File size: 7,161 Bytes
fcc4b6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
#!/usr/bin/env python3
"""Validate public data registries for the WildFIRE-FM release.

The registries are meant to be read by downstream data adapters. This script
keeps references explicit so future task additions do not silently introduce
unregistered sources, targets, grids, or masks.
"""

from __future__ import annotations

import argparse
import re
import sys
from pathlib import Path
from typing import Any

import yaml


# Registry key -> filename expected inside the registry directory. validate()
# loads each of these files and reports any that are missing or unparsable.
REGISTRY_FILES = {
    "sources": "sources.yml",
    "variables": "variables.yml",
    "grids": "grids.yml",
    "tasks": "tasks.yml",
    "splits": "splits.yml",
}

# Regexes searched for in the raw text of each *.yml file in the registry
# directory; every match is reported as a forbidden local path or
# credential-like value.
# NOTE(review): two of the path patterns are assembled from string fragments,
# presumably so that this script's own source never contains the literal
# path - confirm before simplifying them.
FORBIDDEN_PATTERNS = [
    re.compile(r"/home/"),
    re.compile("/" + "blue" + "/"),
    re.compile("/" + "orange" + "/"),
    re.compile(r"hf_[A-Za-z0-9]{20,}"),
    re.compile(r"(?i)password\s*[:=]"),
    re.compile(r"(?i)secret\s*[:=]"),
    re.compile(r"(?i)token\s*[:=]"),
]

# Values a task may use for `observation_mask` even when they are not among
# the mask names declared in variables.yml (see validate_tasks).
ALLOWED_OBSERVATION_MASK_VALUES = {
    "required",
    "event_label_available",
    "station_observation_available",
    "hms_product_available",
    "weather_truth_available",
    "track_observation_available",
    "usdm_week_available",
}


def load_yaml(path: Path) -> dict[str, Any]:
    """Parse the YAML file at ``path`` and return its top-level mapping.

    Args:
        path: Location of the YAML document, read as UTF-8.

    Returns:
        The parsed document.

    Raises:
        ValueError: If the top level of the document is not a mapping.
    """
    with path.open(encoding="utf-8") as stream:
        document = yaml.safe_load(stream)
    if isinstance(document, dict):
        return document
    raise ValueError(f"{path} did not load as a mapping")


def as_list(value: Any) -> list[Any]:
    """Coerce ``value`` to a list.

    ``None`` becomes an empty list, a list is handed back unchanged (the same
    object, not a copy), and anything else is wrapped in a one-element list.
    """
    if isinstance(value, list):
        return value
    return [] if value is None else [value]


def collect_variable_refs(variables: dict[str, Any]) -> tuple[set[str], set[str], set[str], set[str]]:
    """Return the names declared in each variable group.

    Args:
        variables: Parsed top-level mapping of the variables registry.

    Returns:
        A ``(dynamic, static, masks, targets)`` tuple holding the keys of the
        ``dynamic_weather``, ``static_context``, ``masks`` and ``targets``
        groups. A group that is missing, empty, or not a mapping contributes
        an empty set.
    """

    def group_keys(group_name: str) -> set[str]:
        group = variables.get(group_name)
        # A malformed group (for example a YAML list) has no keys to offer.
        # Yielding an empty set instead of raising AttributeError lets the
        # caller go on collecting diagnostics rather than abort.
        return set(group) if isinstance(group, dict) else set()

    return (
        group_keys("dynamic_weather"),
        group_keys("static_context"),
        group_keys("masks"),
        group_keys("targets"),
    )


def check_source_ref(errors: list[str], ref: str, sources: set[str], location: str) -> None:
    """Record an error when ``ref`` is not a known source id.

    Args:
        errors: Diagnostic list, appended to in place.
        ref: Source id being referenced.
        sources: Ids of all registered sources.
        location: Prefix identifying where the reference was found.
    """
    if ref in sources:
        return
    errors.append(f"{location}: unknown source '{ref}'")


def check_forbidden_text(errors: list[str], registry_dir: Path) -> None:
    """Scan every ``*.yml`` file in ``registry_dir`` for forbidden text.

    Files are visited in sorted order. One error is appended per
    (file, pattern) pair for which the pattern matches anywhere in the file.

    Args:
        errors: Diagnostic list, appended to in place.
        registry_dir: Directory whose top-level ``*.yml`` files are scanned.
    """
    for yml_path in sorted(registry_dir.glob("*.yml")):
        contents = yml_path.read_text(encoding="utf-8")
        errors.extend(
            f"{yml_path.name}: forbidden local path or credential-like value matched {regex.pattern!r}"
            for regex in FORBIDDEN_PATTERNS
            if regex.search(contents)
        )


def validate_variables(errors: list[str], variables: dict[str, Any], source_ids: set[str]) -> None:
    """Check that every variable spec only references registered sources.

    The ``dynamic_weather``, ``static_context`` and ``targets`` groups are
    inspected. Each group, and each entry inside it, must be a mapping; the
    ``source`` and ``source_candidates`` fields of an entry (scalar or list)
    must name ids present in ``source_ids``.

    Args:
        errors: Diagnostic list, appended to in place.
        variables: Parsed top-level mapping of the variables registry.
        source_ids: Ids of all registered sources.
    """
    for section in ("dynamic_weather", "static_context", "targets"):
        entries = variables.get(section) or {}
        if not isinstance(entries, dict):
            errors.append(f"variables.yml:{section} must be a mapping")
            continue
        for var_name, var_spec in entries.items():
            where = f"variables.yml:{section}.{var_name}"
            if not isinstance(var_spec, dict):
                errors.append(f"{where} must be a mapping")
                continue
            declared = as_list(var_spec.get("source")) + as_list(var_spec.get("source_candidates"))
            for source_ref in declared:
                check_source_ref(errors, str(source_ref), source_ids, where)


def _is_registered(value: Any, known: set[str]) -> bool:
    """Return whether ``value`` is a member of ``known``; unhashable values never are."""
    try:
        return value in known
    except TypeError:
        # YAML lists and mappings cannot be hashed, so set membership raises.
        # They can never be a registered id: answer "no" instead of crashing.
        return False


def validate_tasks(
    errors: list[str],
    tasks: dict[str, Any],
    source_ids: set[str],
    grid_ids: set[str],
    target_ids: set[str],
    mask_ids: set[str],
) -> None:
    """Check that every task only references registered entities.

    For each entry under ``tasks`` this verifies that ``input_grid`` is a
    known grid, ``target`` is a known target, every item of
    ``dynamic_sources`` / ``static_sources`` is a known source, and a
    non-empty ``observation_mask`` is either a declared mask or one of
    ``ALLOWED_OBSERVATION_MASK_VALUES``. A missing ``input_grid`` or
    ``target`` is reported as unknown (the value shown is ``None``).

    Args:
        errors: Diagnostic list, appended to in place.
        tasks: Parsed top-level mapping of the tasks registry.
        source_ids: Ids of all registered sources.
        grid_ids: Ids of all registered grids.
        target_ids: Names of all declared targets.
        mask_ids: Names of all declared masks.
    """
    task_specs = tasks.get("tasks") or {}
    if not isinstance(task_specs, dict):
        errors.append("tasks.yml:tasks must be a mapping")
        return
    for task_id, spec in task_specs.items():
        if not isinstance(spec, dict):
            errors.append(f"tasks.yml:{task_id} must be a mapping")
            continue
        grid = spec.get("input_grid")
        if not _is_registered(grid, grid_ids):
            errors.append(f"tasks.yml:{task_id}.input_grid unknown grid '{grid}'")
        target = spec.get("target")
        if not _is_registered(target, target_ids):
            errors.append(f"tasks.yml:{task_id}.target unknown target '{target}'")
        for ref in as_list(spec.get("dynamic_sources")) + as_list(spec.get("static_sources")):
            # str() keeps non-string refs (e.g. numbers) hashable and printable.
            check_source_ref(errors, str(ref), source_ids, f"tasks.yml:{task_id}")
        observation_mask = spec.get("observation_mask")
        # The mask is optional: a missing or empty value is accepted as-is.
        if (
            observation_mask
            and not _is_registered(observation_mask, mask_ids)
            and not _is_registered(observation_mask, ALLOWED_OBSERVATION_MASK_VALUES)
        ):
            errors.append(f"tasks.yml:{task_id}.observation_mask unknown mask or policy '{observation_mask}'")


def validate_splits(errors: list[str], splits: dict[str, Any], grid_ids: set[str]) -> None:
    """Check that every split references a registered grid.

    Args:
        errors: Diagnostic list, appended to in place.
        splits: Parsed top-level mapping of the splits registry.
        grid_ids: Ids of all registered grids.

    A missing ``grid`` is reported as unknown (the value shown is ``None``).
    """
    split_specs = splits.get("splits") or {}
    if not isinstance(split_specs, dict):
        errors.append("splits.yml:splits must be a mapping")
        return
    for split_id, spec in split_specs.items():
        if not isinstance(spec, dict):
            errors.append(f"splits.yml:{split_id} must be a mapping")
            continue
        grid = spec.get("grid")
        try:
            known = grid in grid_ids
        except TypeError:
            # An unhashable value (YAML list or mapping) makes set membership
            # raise; it cannot be a grid id, so report it like any unknown grid.
            known = False
        if not known:
            errors.append(f"splits.yml:{split_id}.grid unknown grid '{grid}'")


def validate(registry_dir: Path) -> list[str]:
    """Validate all registry files in ``registry_dir``.

    Every file named in ``REGISTRY_FILES`` is loaded first. If any is missing
    or fails to load, those problems are returned immediately because the
    cross-reference checks need all of them. Otherwise the forbidden-text
    scan and the variable, task and split reference checks are run.

    Args:
        registry_dir: Directory containing the registry YAML files.

    Returns:
        Human-readable error messages; an empty list means validation passed.
    """
    errors: list[str] = []
    loaded: dict[str, dict[str, Any]] = {}
    for key, filename in REGISTRY_FILES.items():
        path = registry_dir / filename
        if not path.exists():
            errors.append(f"missing registry file: {path}")
            continue
        try:
            loaded[key] = load_yaml(path)
        except Exception as exc:  # noqa: BLE001 - keep CLI diagnostics concise.
            errors.append(f"{filename}: failed to parse YAML: {exc}")
    if errors:
        return errors

    def declared_ids(key: str) -> set[str]:
        # The id section shares its name with the registry key, e.g. the
        # "sources" mapping inside sources.yml.
        section = loaded[key].get(key) or {}
        if not isinstance(section, dict):
            # Report a malformed section (e.g. a YAML list) instead of
            # crashing with AttributeError on a missing .keys().
            errors.append(f"{REGISTRY_FILES[key]}:{key} must be a mapping")
            return set()
        return set(section)

    source_ids = declared_ids("sources")
    grid_ids = declared_ids("grids")
    _, _, mask_ids, target_ids = collect_variable_refs(loaded["variables"])

    if not source_ids:
        errors.append("sources.yml contains no sources")
    if not grid_ids:
        errors.append("grids.yml contains no grids")
    if not target_ids:
        errors.append("variables.yml contains no targets")

    check_forbidden_text(errors, registry_dir)
    validate_variables(errors, loaded["variables"], source_ids)
    validate_tasks(errors, loaded["tasks"], source_ids, grid_ids, target_ids, mask_ids)
    validate_splits(errors, loaded["splits"], grid_ids)
    return errors


def main(argv: list[str] | None = None) -> int:
    """Command-line entry point.

    Args:
        argv: Argument list to parse; ``None`` lets argparse read the
            process arguments.

    Returns:
        ``0`` when validation passes, ``1`` when any error was reported.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument(
        "--registry-dir",
        type=Path,
        default="registries",
        help="Directory containing sources.yml, variables.yml, grids.yml, tasks.yml, and splits.yml.",
    )
    registry_dir = cli.parse_args(argv).registry_dir

    problems = validate(registry_dir)
    if not problems:
        print(f"Registry validation passed: {registry_dir}")
        return 0

    print("Registry validation failed:")
    for problem in problems:
        print(f"  - {problem}")
    return 1


# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())