# eee_validator/post_merge.py
# Author: deepmage121 — commit 051d3f3: "post_merge fix with reset option"
"""Post-merge actions: update manifest.json and dataset card on PR merge."""
import io
import json
import logging
import os
from datetime import datetime, timezone
from huggingface_hub import HfApi, hf_hub_download, snapshot_download
from huggingface_hub.utils import EntryNotFoundError
from dedup import DATASET_REPO_ID, compute_fingerprint, compute_sha256
logger = logging.getLogger(__name__)
api = HfApi()
def _list_data_dirs(api: HfApi) -> list[str]:
    """List top-level directory names under data/ (these become config/split names)."""
    tree = api.list_repo_tree(
        repo_id=DATASET_REPO_ID,
        repo_type="dataset",
        revision="main",
        path_in_repo="data",
    )
    # Directory entries carry no `rfilename` attribute (only file entries do);
    # a folder's `path` looks like "data/<name>", so strip the leading "data/".
    names = [
        entry.path.split("/", 1)[-1]
        for entry in tree
        if not hasattr(entry, "rfilename")
    ]
    return sorted(names)
def _build_dataset_card(configs: list[str]) -> str:
"""Build a dataset card README.md with YAML frontmatter for the viewer."""
yaml_configs = []
for config in configs:
yaml_configs.append(f" - config_name: {config}")
yaml_configs.append(f" data_files:")
yaml_configs.append(f" - split: train")
yaml_configs.append(f" path: data/{config}/**/*.json")
yaml_block = "\n".join(yaml_configs)
return f"""---
configs:
{yaml_block}
license: mit
---
# EEE Datastore
Evaluation data for the EEE project.
"""
def update_manifest(api: HfApi, merged_files: list[str]) -> None:
    """Download merged files from main, compute hashes, and update manifest.json.

    Args:
        api: HfApi client used for the final upload.
        merged_files: Repo-relative paths (e.g. "data/foo/bar.json") to add
            or refresh in the manifest.
    """
    # Load the existing manifest; fall back to an empty one, but always log
    # why so a reset is never silent.
    try:
        manifest_path = hf_hub_download(
            repo_id=DATASET_REPO_ID,
            filename="manifest.json",
            repo_type="dataset",
            revision="main",
        )
        with open(manifest_path, "r") as f:
            manifest = json.load(f)
    except EntryNotFoundError:
        logger.info("No manifest.json on main yet; starting a new one")
        manifest = {"files": {}}
    except Exception:
        # BUG FIX: the original `except (EntryNotFoundError, Exception)` was
        # effectively a bare `except Exception` that silently reset the
        # manifest on any failure (network error, corrupt JSON, ...), hiding
        # the cause. Keep the best-effort fallback but record the reason.
        logger.exception("Could not load existing manifest.json; resetting")
        manifest = {"files": {}}
    now = datetime.now(timezone.utc).isoformat()
    for file_path in merged_files:
        try:
            local_path = hf_hub_download(
                repo_id=DATASET_REPO_ID,
                filename=file_path,
                repo_type="dataset",
                revision="main",
            )
            with open(local_path, "rb") as f:
                content = f.read()
            sha256 = compute_sha256(content)
            # Canonical fingerprints apply only to .json payloads; everything
            # else (e.g. .jsonl) falls back to its sha256.
            fingerprint = (
                compute_fingerprint(content) if file_path.endswith(".json") else sha256
            )
            manifest["files"][file_path] = {
                "sha256": sha256,
                "fingerprint": fingerprint,
                "added_at": now,
            }
            logger.info("Added %s to manifest", file_path)
        except Exception:
            # Best-effort per file: one bad file must not block the rest.
            logger.exception("Failed to process %s for manifest", file_path)
    # Upload updated manifest
    manifest_bytes = json.dumps(manifest, indent=2, sort_keys=True).encode()
    api.upload_file(
        path_or_fileobj=io.BytesIO(manifest_bytes),
        path_in_repo="manifest.json",
        repo_id=DATASET_REPO_ID,
        repo_type="dataset",
        commit_message="Update manifest.json after PR merge",
    )
    logger.info("Uploaded updated manifest.json (%d files)", len(manifest["files"]))
def update_dataset_card(api: HfApi) -> None:
    """Regenerate the dataset card README.md with configs for all data/ subdirs."""
    config_names = _list_data_dirs(api)
    # Nothing to render — leave the existing card untouched rather than
    # uploading an empty configs list.
    if not config_names:
        logger.warning("No data directories found, skipping dataset card update")
        return
    readme = _build_dataset_card(config_names)
    api.upload_file(
        path_or_fileobj=io.BytesIO(readme.encode()),
        path_in_repo="README.md",
        repo_id=DATASET_REPO_ID,
        repo_type="dataset",
        commit_message="Update dataset card with configs for viewer",
    )
    logger.info("Updated dataset card with %d configs: %s", len(config_names), config_names)
def _index_snapshot(local_dir: str, now: str) -> dict:
    """Walk <local_dir>/data and return a manifest "files" mapping.

    Args:
        local_dir: Root directory of a local dataset snapshot.
        now: ISO timestamp recorded as `added_at` on every entry.

    Returns:
        Mapping of repo-relative path -> {sha256, fingerprint, added_at}.
    """
    files: dict = {}
    data_root = os.path.join(local_dir, "data")
    if not os.path.isdir(data_root):
        logger.warning("No data/ directory found in snapshot")
        return files
    for dirpath, _, filenames in os.walk(data_root):
        for filename in filenames:
            if not filename.endswith((".json", ".jsonl")):
                continue
            local_path = os.path.join(dirpath, filename)
            # BUG FIX: repo paths always use forward slashes, but
            # os.path.relpath emits backslashes on Windows — normalize so
            # manifest keys match repo paths on every platform.
            repo_path = os.path.relpath(local_path, local_dir).replace(os.sep, "/")
            try:
                with open(local_path, "rb") as f:
                    content = f.read()
                sha256 = compute_sha256(content)
                # Canonical fingerprints apply only to .json; .jsonl falls
                # back to its sha256 (same policy as update_manifest).
                fingerprint = (
                    compute_fingerprint(content) if filename.endswith(".json") else sha256
                )
                files[repo_path] = {
                    "sha256": sha256,
                    "fingerprint": fingerprint,
                    "added_at": now,
                }
                logger.info("Indexed %s", repo_path)
            except Exception:
                # Best-effort: one unreadable file must not abort the rebuild.
                logger.exception("Failed to index %s", repo_path)
    return files


def full_rebuild() -> dict:
    """Rebuild manifest.json from scratch and regenerate the dataset card.

    Downloads the entire dataset at once via snapshot_download, then walks
    the local directory to compute hashes/fingerprints for all data files.

    Returns:
        Summary dict with status, action, and number of files indexed.
    """
    logger.info("Starting full rebuild of manifest + dataset card")
    # Download entire dataset in one shot
    local_dir = snapshot_download(
        repo_id=DATASET_REPO_ID,
        repo_type="dataset",
        revision="main",
    )
    logger.info("Downloaded dataset snapshot to %s", local_dir)
    now = datetime.now(timezone.utc).isoformat()
    manifest = {"files": _index_snapshot(local_dir, now)}
    logger.info("Indexed %d data files total", len(manifest["files"]))
    # Upload fresh manifest
    manifest_bytes = json.dumps(manifest, indent=2, sort_keys=True).encode()
    api.upload_file(
        path_or_fileobj=io.BytesIO(manifest_bytes),
        path_in_repo="manifest.json",
        repo_id=DATASET_REPO_ID,
        repo_type="dataset",
        commit_message="Full rebuild of manifest.json",
    )
    logger.info("Uploaded rebuilt manifest.json (%d files)", len(manifest["files"]))
    # Regenerate dataset card
    update_dataset_card(api)
    return {
        "status": "ok",
        "action": "full_rebuild",
        "files_indexed": len(manifest["files"]),
    }
def handle_merge(pr_num: int) -> dict:
    """Run all post-merge actions for a PR.

    After the merge everything is already on main, so instead of diffing the
    PR we list all data files on main and add any the manifest does not yet
    track.

    Args:
        pr_num: Number of the merged PR (used for logging and the summary).

    Returns:
        Summary dict with status, action, pr, and files added to manifest.
    """
    logger.info("Handling merge for PR #%d", pr_num)
    try:
        manifest_path = hf_hub_download(
            repo_id=DATASET_REPO_ID,
            filename="manifest.json",
            repo_type="dataset",
            revision="main",
        )
        with open(manifest_path, "r") as f:
            manifest = json.load(f)
    except EntryNotFoundError:
        logger.info("No manifest.json on main yet; treating all files as untracked")
        manifest = {"files": {}}
    except Exception:
        # BUG FIX: `except (EntryNotFoundError, Exception)` was effectively a
        # bare `except Exception` that silently discarded the failure cause.
        # Keep the empty-manifest fallback but log why it happened.
        logger.exception("Could not load existing manifest.json; assuming empty")
        manifest = {"files": {}}
    # List all data files on main and find untracked ones
    all_data_files: list[str] = []
    for entry in api.list_repo_tree(
        repo_id=DATASET_REPO_ID,
        repo_type="dataset",
        revision="main",
        recursive=True,
    ):
        if not hasattr(entry, "rfilename"):
            continue  # directory entries carry no rfilename; skip them
        path = entry.rfilename
        if path.startswith("data/") and path.endswith((".json", ".jsonl")):
            all_data_files.append(path)
    tracked = manifest.get("files", {})
    untracked = [f for f in all_data_files if f not in tracked]
    logger.info("Found %d untracked data files after merge", len(untracked))
    if untracked:
        update_manifest(api, untracked)
    update_dataset_card(api)
    return {
        "status": "ok",
        "action": "post_merge",
        "pr": pr_num,
        "files_added_to_manifest": len(untracked),
    }
if __name__ == "__main__":
    # Running the module directly performs a full manifest + card rebuild
    # (the "reset option"); per-PR merge handling is invoked programmatically
    # via handle_merge() instead.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )
    result = full_rebuild()
    # Emit the summary dict as JSON for the calling process/operator.
    print(json.dumps(result, indent=2))