#!/usr/bin/env python3
"""Validate public data registries for the WildFIRE-FM release.
The registries are meant to be read by downstream data adapters. This script
keeps references explicit so future task additions do not silently introduce
unregistered sources, targets, grids, or masks.
"""
from __future__ import annotations
import argparse
import re
import sys
from pathlib import Path
from typing import Any
import yaml
# Registry key -> YAML filename expected inside the registry directory.
# `validate` loads every entry listed here and reports any file that is missing.
REGISTRY_FILES = {
    "sources": "sources.yml",
    "variables": "variables.yml",
    "grids": "grids.yml",
    "tasks": "tasks.yml",
    "splits": "splits.yml",
}
# Regexes that must not match anywhere in a registry file: local filesystem
# paths and credential-like assignments.
# NOTE(review): the "/" + name + "/" pieces are presumably concatenated so this
# script never contains the literal path it forbids - confirm before inlining.
FORBIDDEN_PATTERNS = [
    re.compile(pattern_source)
    for pattern_source in (
        r"/home/",
        "/" + "blue" + "/",
        "/" + "orange" + "/",
        r"hf_[A-Za-z0-9]{20,}",
        r"(?i)password\s*[:=]",
        r"(?i)secret\s*[:=]",
        r"(?i)token\s*[:=]",
    )
]
# Values a task's `observation_mask` may take in addition to the mask names
# declared in variables.yml; `validate_tasks` accepts a value found in either.
ALLOWED_OBSERVATION_MASK_VALUES = {
    "required",
    "event_label_available",
    "station_observation_available",
    "hms_product_available",
    "weather_truth_available",
    "track_observation_available",
    "usdm_week_available",
}
def load_yaml(path: Path) -> dict[str, Any]:
    """Parse the YAML file at *path* and return its top-level mapping.

    Args:
        path: Location of the YAML document, read as UTF-8.

    Returns:
        The parsed document.

    Raises:
        ValueError: If the parsed document is not a ``dict``.
    """
    with path.open("r", encoding="utf-8") as stream:
        document = yaml.safe_load(stream)
    if isinstance(document, dict):
        return document
    raise ValueError(f"{path} did not load as a mapping")
def as_list(value: Any) -> list[Any]:
    """Normalize an optional scalar-or-list field to a list.

    A list is returned unchanged (same object), ``None`` becomes an empty
    list, and any other value is wrapped in a one-element list.
    """
    if isinstance(value, list):
        return value
    return [] if value is None else [value]
def collect_variable_refs(variables: dict[str, Any]) -> tuple[set[str], set[str], set[str], set[str]]:
    """Return the names declared in each group of the variables registry.

    Args:
        variables: Parsed content of ``variables.yml``.

    Returns:
        ``(dynamic, static, masks, targets)``: the key sets of the
        ``dynamic_weather``, ``static_context``, ``masks`` and ``targets``
        groups. A group that is missing, empty, or not a mapping yields an
        empty set rather than raising.
    """

    def group_names(group_name: str) -> set[str]:
        group = variables.get(group_name)
        # A malformed group (e.g. a YAML list or scalar) declares no names;
        # returning an empty set keeps the validator producing diagnostics
        # instead of dying on a missing ``.keys()``.
        return set(group) if isinstance(group, dict) else set()

    return (
        group_names("dynamic_weather"),
        group_names("static_context"),
        group_names("masks"),
        group_names("targets"),
    )
def check_source_ref(errors: list[str], ref: str, sources: set[str], location: str) -> None:
    """Append an error for *location* when *ref* is not among *sources*.

    Args:
        errors: Accumulator that receives the diagnostic message.
        ref: Source id being referenced.
        sources: Ids of all registered sources.
        location: Prefix identifying where the reference was found.
    """
    if ref in sources:
        return
    errors.append(f"{location}: unknown source '{ref}'")
def check_forbidden_text(errors: list[str], registry_dir: Path) -> None:
    """Scan every ``*.yml`` file in *registry_dir* for forbidden text.

    Files are visited in sorted order and read as UTF-8. One error is
    appended per (file, pattern) pair for which the pattern matches
    anywhere in the file's text.
    """
    for yml_path in sorted(registry_dir.glob("*.yml")):
        content = yml_path.read_text(encoding="utf-8")
        errors.extend(
            f"{yml_path.name}: forbidden local path or credential-like value matched {regex.pattern!r}"
            for regex in FORBIDDEN_PATTERNS
            if regex.search(content)
        )
def validate_variables(errors: list[str], variables: dict[str, Any], source_ids: set[str]) -> None:
    """Check the source references made by entries of the variables registry.

    The ``dynamic_weather``, ``static_context`` and ``targets`` groups are
    inspected. Each group and each entry must be a mapping; every value
    listed under an entry's ``source`` or ``source_candidates`` is
    stringified and must be a registered source id.

    Args:
        errors: Accumulator that receives diagnostic messages.
        variables: Parsed content of ``variables.yml``.
        source_ids: Ids of all registered sources.
    """
    for group_name in ("dynamic_weather", "static_context", "targets"):
        entries = variables.get(group_name) or {}
        if not isinstance(entries, dict):
            errors.append(f"variables.yml:{group_name} must be a mapping")
            continue
        for var_name, var_spec in entries.items():
            if not isinstance(var_spec, dict):
                errors.append(f"variables.yml:{group_name}.{var_name} must be a mapping")
                continue
            where = f"variables.yml:{group_name}.{var_name}"
            # Both fields accept a single id or a list of ids.
            candidates = as_list(var_spec.get("source")) + as_list(var_spec.get("source_candidates"))
            for candidate in candidates:
                check_source_ref(errors, str(candidate), source_ids, where)
def validate_tasks(
    errors: list[str],
    tasks: dict[str, Any],
    source_ids: set[str],
    grid_ids: set[str],
    target_ids: set[str],
    mask_ids: set[str],
) -> None:
    """Check that every task only references registered ids.

    For each entry under ``tasks`` this verifies that ``input_grid`` is a
    known grid, ``target`` is a known target, every value in
    ``dynamic_sources`` and ``static_sources`` is a known source, and a
    non-empty ``observation_mask`` is either a declared mask or one of
    ``ALLOWED_OBSERVATION_MASK_VALUES``.

    Args:
        errors: Accumulator that receives diagnostic messages.
        tasks: Parsed content of ``tasks.yml``.
        source_ids: Ids of all registered sources.
        grid_ids: Ids of all registered grids.
        target_ids: Names of all declared targets.
        mask_ids: Names of all declared masks.
    """

    def is_registered(value: Any, *registries: set[str]) -> bool:
        # A YAML list or mapping is unhashable, so set membership raises
        # TypeError; such a value can never be a registered id, so report it
        # as unknown instead of crashing the validator.
        try:
            return any(value in registry for registry in registries)
        except TypeError:
            return False

    task_specs = tasks.get("tasks") or {}
    if not isinstance(task_specs, dict):
        errors.append("tasks.yml:tasks must be a mapping")
        return
    for task_id, spec in task_specs.items():
        if not isinstance(spec, dict):
            errors.append(f"tasks.yml:{task_id} must be a mapping")
            continue
        grid = spec.get("input_grid")
        if not is_registered(grid, grid_ids):
            errors.append(f"tasks.yml:{task_id}.input_grid unknown grid '{grid}'")
        target = spec.get("target")
        if not is_registered(target, target_ids):
            errors.append(f"tasks.yml:{task_id}.target unknown target '{target}'")
        for ref in as_list(spec.get("dynamic_sources")) + as_list(spec.get("static_sources")):
            check_source_ref(errors, str(ref), source_ids, f"tasks.yml:{task_id}")
        observation_mask = spec.get("observation_mask")
        # A missing or empty observation_mask is allowed and skipped.
        if observation_mask and not is_registered(
            observation_mask, mask_ids, ALLOWED_OBSERVATION_MASK_VALUES
        ):
            errors.append(f"tasks.yml:{task_id}.observation_mask unknown mask or policy '{observation_mask}'")
def validate_splits(errors: list[str], splits: dict[str, Any], grid_ids: set[str]) -> None:
    """Check that every split references a registered grid.

    Args:
        errors: Accumulator that receives diagnostic messages.
        splits: Parsed content of ``splits.yml``.
        grid_ids: Ids of all registered grids.
    """
    split_specs = splits.get("splits") or {}
    if not isinstance(split_specs, dict):
        errors.append("splits.yml:splits must be a mapping")
        return
    for split_id, spec in split_specs.items():
        if not isinstance(spec, dict):
            errors.append(f"splits.yml:{split_id} must be a mapping")
            continue
        grid = spec.get("grid")
        try:
            grid_is_known = grid in grid_ids
        except TypeError:
            # An unhashable value (YAML list or mapping) cannot be a grid id;
            # report it as unknown instead of crashing the validator.
            grid_is_known = False
        if not grid_is_known:
            errors.append(f"splits.yml:{split_id}.grid unknown grid '{grid}'")
def validate(registry_dir: Path) -> list[str]:
    """Validate all registry files in *registry_dir*.

    Every file named in ``REGISTRY_FILES`` is loaded first; if any file is
    missing or fails to parse, those errors are returned immediately and no
    cross-reference checks run. Otherwise the source, grid, target and mask
    ids are collected and the forbidden-text, variable, task and split
    checks are applied in that order.

    Args:
        registry_dir: Directory containing the registry YAML files.

    Returns:
        Diagnostic messages; an empty list means validation passed.
    """
    errors: list[str] = []
    loaded: dict[str, dict[str, Any]] = {}
    for key, filename in REGISTRY_FILES.items():
        path = registry_dir / filename
        if not path.exists():
            errors.append(f"missing registry file: {path}")
            continue
        try:
            loaded[key] = load_yaml(path)
        except Exception as exc:  # noqa: BLE001 - keep CLI diagnostics concise.
            errors.append(f"{filename}: failed to parse YAML: {exc}")
    if errors:
        return errors

    def declared_ids(key: str) -> set[str]:
        # The ids are the keys of the top-level section named after the file.
        section = loaded[key].get(key) or {}
        if not isinstance(section, dict):
            # A list or scalar section has no ``.keys()``; report the wrong
            # shape as a diagnostic instead of raising AttributeError.
            errors.append(f"{REGISTRY_FILES[key]}:{key} must be a mapping")
            return set()
        return set(section)

    source_ids = declared_ids("sources")
    grid_ids = declared_ids("grids")
    _, _, mask_ids, target_ids = collect_variable_refs(loaded["variables"])
    if not source_ids:
        errors.append("sources.yml contains no sources")
    if not grid_ids:
        errors.append("grids.yml contains no grids")
    if not target_ids:
        errors.append("variables.yml contains no targets")
    check_forbidden_text(errors, registry_dir)
    validate_variables(errors, loaded["variables"], source_ids)
    validate_tasks(errors, loaded["tasks"], source_ids, grid_ids, target_ids, mask_ids)
    validate_splits(errors, loaded["splits"], grid_ids)
    return errors
def main(argv: list[str] | None = None) -> int:
    """Run the registry validator from the command line.

    Args:
        argv: Argument list handed to the parser; ``None`` is passed through
            to ``parse_args`` unchanged.

    Returns:
        ``0`` when validation passes, ``1`` when any error was reported.
        The outcome and each error are printed to standard output.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument(
        "--registry-dir",
        default="registries",
        type=Path,
        help="Directory containing sources.yml, variables.yml, grids.yml, tasks.yml, and splits.yml.",
    )
    registry_dir = cli.parse_args(argv).registry_dir
    problems = validate(registry_dir)
    if not problems:
        print(f"Registry validation passed: {registry_dir}")
        return 0
    print("Registry validation failed:")
    for problem in problems:
        print(f" - {problem}")
    return 1
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    exit_status = main()
    sys.exit(exit_status)
|