msrishav's picture
Add inference code, config, and technical report
e68eb1c verified
Raw
History Blame Contribute Delete
7.49 kB
"""Official-protocol checks, dataset manifests, and locked sequence folds."""
from __future__ import annotations
import hashlib
import json
import subprocess
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Iterable, Optional
import numpy as np
import pandas as pd
# Metadata columns that must lead every dataframe, in exactly this order;
# every column after them is treated as a feature column.
REQUIRED_META_COLUMNS = ["seq_ix", "step_in_seq", "need_prediction"]
# Default number of feature columns expected after the metadata columns.
EXPECTED_N_FEATURES = 32
# Default number of rows per sequence (steps 0 .. EXPECTED_SEQUENCE_LENGTH - 1).
EXPECTED_SEQUENCE_LENGTH = 1000
@dataclass(frozen=True)
class DatasetManifest:
    """Immutable reproducibility record describing one validated parquet file."""

    # Path of the parquet file, stored as a string.
    path: str
    # Hex SHA-256 digest of the file contents.
    sha256: str
    # File size on disk, in bytes.
    size_bytes: int
    # Total number of rows in the dataframe.
    row_count: int
    # Number of distinct seq_ix values.
    sequence_count: int
    # Number of rows in a sequence (taken from the first seq_ix group).
    sequence_length: int
    # Feature column names in parquet order (everything after the metadata columns).
    feature_columns: list[str]
    # Sum of need_prediction per step_in_seq; keys are the step values as strings.
    need_prediction_counts_by_step: dict[str, int]
    # Smallest seq_ix present in the file.
    seq_ix_min: int
    # Largest seq_ix present in the file.
    seq_ix_max: int
@dataclass(frozen=True)
class FoldManifest:
    """Immutable record of a seeded split of sequence ids into folds."""

    # Seed passed to the random generator that shuffles the sequence ids.
    seed: int
    # Number of folds the sequence ids were split into.
    n_folds: int
    # Index of the fold reserved as the final holdout.
    final_holdout_fold: int
    # Fold index (as a string) -> sorted sequence ids assigned to that fold.
    folds: dict[str, list[int]]
    # Sorted sequence ids of every fold except the final holdout fold.
    train_dev_seq_ids: list[int]
    # Sorted sequence ids of the final holdout fold.
    final_holdout_seq_ids: list[int]
    # Optional SHA-256 of the dataset the folds were built from.
    dataset_sha256: Optional[str] = None
    # Git HEAD commit recorded when the folds were created, if it could be read.
    git_commit: Optional[str] = None
def resolve_dataset_path(preferred: str | Path = "data/raw/train.parquet") -> Path:
    """Find the canonical raw train parquet without mutating the workspace.

    The ``preferred`` path is tried first, followed by a fixed list of
    conventional locations relative to the current working directory. The
    first candidate that exists is returned.

    Args:
        preferred: Location to try before any of the built-in fallbacks.

    Returns:
        The first existing candidate, as a ``Path``.

    Raises:
        FileNotFoundError: If none of the candidates exists. The message lists
            every path that was searched, including ``preferred``.
    """
    fallbacks = (
        "data/raw/train.parquet",
        "data/train.parquet",
        "competition_package/datasets/train.parquet",
        "wnn_starterpack/competition_package/datasets/train.parquet",
        "datasets/train.parquet",
        "train.parquet",
    )
    # dict.fromkeys drops duplicates (the default ``preferred`` is also a
    # fallback) while keeping the search order intact.
    candidates = list(dict.fromkeys(Path(p) for p in (preferred, *fallbacks)))
    for candidate in candidates:
        if candidate.exists():
            return candidate
    searched = ", ".join(str(candidate) for candidate in candidates)
    raise FileNotFoundError(
        f"No train parquet found. Searched (in order): {searched}"
    )
def sha256_file(path: str | Path, chunk_size: int = 1024 * 1024) -> str:
    """Return the hex SHA-256 digest of the file at ``path``.

    The file is read in pieces of ``chunk_size`` bytes so it never has to be
    held in memory as a whole.
    """
    hasher = hashlib.sha256()
    with Path(path).open("rb") as stream:
        while True:
            block = stream.read(chunk_size)
            if not block:
                # An empty read marks end of file.
                break
            hasher.update(block)
    return hasher.hexdigest()
def get_feature_columns(df: pd.DataFrame) -> list[str]:
    """Return the feature columns in dataframe order (never re-sorted).

    Features are every column positioned after the leading metadata columns;
    the slice is purely positional and does not inspect column names.
    """
    n_meta = len(REQUIRED_META_COLUMNS)
    all_columns = list(df.columns)
    return all_columns[n_meta:]
def validate_wunder_dataframe(
    df: pd.DataFrame,
    *,
    expected_n_features: int = EXPECTED_N_FEATURES,
    expected_sequence_length: int = EXPECTED_SEQUENCE_LENGTH,
) -> list[str]:
    """Check ``df`` against the official data contract and return its feature columns.

    Checks run in this order, and the first failure raises:

    1. the first three columns are the required metadata columns, in order;
    2. the number of remaining (feature) columns equals ``expected_n_features``;
    3. the metadata columns contain no missing values;
    4. every ``seq_ix`` group has exactly ``expected_sequence_length`` rows;
    5. within every group, ``step_in_seq`` reads 0, 1, ... in row order;
    6. the feature columns contain no missing values.

    Args:
        df: Dataframe to validate.
        expected_n_features: Required number of feature columns.
        expected_sequence_length: Required number of rows per sequence.

    Returns:
        The feature column names, in dataframe order.

    Raises:
        ValueError: If any of the checks above fails.
    """
    leading_columns = list(df.columns[:3])
    if leading_columns != REQUIRED_META_COLUMNS:
        raise ValueError(
            "First columns must be seq_ix, step_in_seq, need_prediction; "
            f"got {leading_columns}"
        )

    feature_cols = get_feature_columns(df)
    n_found = len(feature_cols)
    if n_found != expected_n_features:
        raise ValueError(
            f"Expected {expected_n_features} feature columns, got {n_found}"
        )

    meta_has_missing = df[REQUIRED_META_COLUMNS].isna().to_numpy().any()
    if meta_has_missing:
        raise ValueError("Metadata columns contain missing values")

    # Row count per sequence, indexed by seq_ix in ascending order.
    rows_per_sequence = df.groupby("seq_ix", sort=True).size()
    length_mismatch = (rows_per_sequence != expected_sequence_length).to_numpy()
    wrong_length_ids = rows_per_sequence.index[length_mismatch]
    if len(wrong_length_ids) > 0:
        raise ValueError(
            "Every sequence must have exactly "
            f"{expected_sequence_length} rows; bad seq_ix values: "
            f"{wrong_length_ids.tolist()[:10]}"
        )

    # Each sequence must list its steps in order, as they appear in the frame.
    canonical_steps = np.arange(expected_sequence_length)
    last_step = expected_sequence_length - 1
    for seq_ix, step_series in df.groupby("seq_ix", sort=False)["step_in_seq"]:
        if np.array_equal(step_series.to_numpy(), canonical_steps):
            continue
        raise ValueError(
            f"Sequence {seq_ix} must be ordered with steps 0.."
            f"{last_step}"
        )

    if df[feature_cols].isna().to_numpy().any():
        raise ValueError("Feature columns contain missing values")
    return feature_cols
def load_wunder_dataframe(path: str | Path, seq_ids: Optional[Iterable[int]] = None) -> pd.DataFrame:
    """Read the parquet file at ``path``, order it, and validate it.

    Args:
        path: Parquet file to read.
        seq_ids: Optional sequence ids to keep; when given, rows whose
            ``seq_ix`` is not among them are dropped before validation.

    Returns:
        The dataframe sorted by ``seq_ix`` then ``step_in_seq`` with a fresh
        0-based index, after passing ``validate_wunder_dataframe``.
    """
    frame = pd.read_parquet(path)
    if seq_ids is not None:
        wanted = {int(seq_id) for seq_id in seq_ids}
        keep_mask = frame["seq_ix"].isin(wanted)
        frame = frame.loc[keep_mask].copy()
    frame = frame.sort_values(by=["seq_ix", "step_in_seq"])
    frame = frame.reset_index(drop=True)
    validate_wunder_dataframe(frame)
    return frame
def create_dataset_manifest(path: str | Path) -> DatasetManifest:
    """Build a reproducibility manifest for a validated parquet file.

    The file is loaded and validated via ``load_wunder_dataframe`` and then
    summarised: content hash, size, row/sequence counts, feature columns,
    per-step ``need_prediction`` totals, and the ``seq_ix`` range.

    Args:
        path: Parquet file to describe.

    Returns:
        A ``DatasetManifest`` for the file.

    Raises:
        ValueError: If the file fails validation, or contains no rows (an
            empty frame passes validation but has nothing to summarise).
    """
    path = Path(path)
    df = load_wunder_dataframe(path)
    if df.empty:
        # Without this guard the summary below fails with an opaque IndexError.
        raise ValueError(f"Dataset {path} contains no rows; cannot build a manifest")

    feature_cols = get_feature_columns(df)
    need_totals = df.groupby("step_in_seq")["need_prediction"].sum()
    rows_per_sequence = df.groupby("seq_ix").size()
    seq_ids = sorted(int(x) for x in df["seq_ix"].unique())
    return DatasetManifest(
        path=str(path),
        sha256=sha256_file(path),
        size_bytes=path.stat().st_size,
        row_count=int(len(df)),
        sequence_count=len(seq_ids),
        sequence_length=int(rows_per_sequence.iloc[0]),
        feature_columns=feature_cols,
        # JSON object keys must be strings, so step values are stringified.
        need_prediction_counts_by_step={
            str(step): int(total) for step, total in need_totals.items()
        },
        seq_ix_min=seq_ids[0],
        seq_ix_max=seq_ids[-1],
    )
def write_json(path: str | Path, payload: object) -> None:
    """Serialize ``payload`` to ``path`` as indented, key-sorted UTF-8 JSON.

    Missing parent directories are created first. Objects exposing
    ``__dataclass_fields__`` are converted with ``dataclasses.asdict``; any
    other payload is handed to ``json.dumps`` as-is.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    data = asdict(payload) if hasattr(payload, "__dataclass_fields__") else payload
    # Serialize before opening so a failure does not truncate an existing file.
    text = json.dumps(data, indent=2, sort_keys=True)
    with target.open("w", encoding="utf-8") as handle:
        handle.write(text)
def read_json(path: str | Path) -> dict:
    """Load and return the JSON document stored at ``path`` (read as UTF-8)."""
    with Path(path).open("r", encoding="utf-8") as handle:
        return json.load(handle)
def current_git_commit() -> Optional[str]:
    """Return the output of ``git rev-parse HEAD``, or ``None`` on any failure.

    This is a best-effort lookup: every exception (git missing, not a
    repository, non-zero exit, ...) is deliberately mapped to ``None``.
    """
    command = ["git", "rev-parse", "HEAD"]
    try:
        completed = subprocess.run(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,
            text=True,
            check=True,
        )
        return completed.stdout.strip()
    except Exception:
        return None
def create_sequence_folds(
    seq_ids: Iterable[int],
    *,
    n_folds: int = 5,
    seed: int = 42,
    final_holdout_fold: int = 4,
    dataset_sha256: Optional[str] = None,
) -> FoldManifest:
    """Create deterministic folds by seq_ix with one untouched final holdout fold.

    The ids are deduplicated, sorted, shuffled with a seeded ``RandomState``,
    and split into ``n_folds`` near-equal parts. Deduplicating first means a
    sequence can never land in more than one fold, even when the caller passes
    a raw ``seq_ix`` column with repeated values. For input that is already
    unique the result is identical to splitting the sorted ids directly.

    Args:
        seq_ids: Sequence ids to split; duplicates are ignored.
        n_folds: Number of folds to create (at least 1).
        seed: Seed for the shuffle.
        final_holdout_fold: Index of the fold reserved as the final holdout.
        dataset_sha256: Optional dataset hash stored on the manifest.

    Returns:
        A ``FoldManifest`` with sorted ids per fold, the train/dev ids (every
        fold but the holdout), the holdout ids, and the current git commit.

    Raises:
        ValueError: If ``n_folds`` is below 1, there are fewer unique
            sequences than folds, or ``final_holdout_fold`` is out of range.
    """
    if n_folds < 1:
        raise ValueError("n_folds must be at least 1")
    # A set removes repeated ids, which would otherwise leak across folds.
    # int64 keeps large ids safe where the platform ``int`` is 32-bit.
    unique_ids = np.array(sorted({int(x) for x in seq_ids}), dtype=np.int64)
    if len(unique_ids) < n_folds:
        raise ValueError("Need at least as many sequences as folds")
    if not 0 <= final_holdout_fold < n_folds:
        raise ValueError("final_holdout_fold must be in [0, n_folds)")

    rng = np.random.RandomState(seed)
    shuffled = unique_ids.copy()
    rng.shuffle(shuffled)
    folds = {
        str(fold_index): sorted(int(x) for x in part.tolist())
        for fold_index, part in enumerate(np.array_split(shuffled, n_folds))
    }

    holdout_key = str(final_holdout_fold)
    final_holdout = folds[holdout_key]
    train_dev = sorted(
        seq_id
        for fold_key, fold_seq_ids in folds.items()
        if fold_key != holdout_key
        for seq_id in fold_seq_ids
    )
    return FoldManifest(
        seed=seed,
        n_folds=n_folds,
        final_holdout_fold=final_holdout_fold,
        folds=folds,
        train_dev_seq_ids=train_dev,
        final_holdout_seq_ids=final_holdout,
        dataset_sha256=dataset_sha256,
        git_commit=current_git_commit(),
    )
def load_sequence_folds(path: str | Path) -> FoldManifest:
    """Rebuild a ``FoldManifest`` from the JSON file at ``path``.

    The top-level JSON keys are passed to the constructor as keyword
    arguments, so they must match the manifest's field names.
    """
    manifest_fields = read_json(path)
    manifest = FoldManifest(**manifest_fields)
    return manifest