"""Post-merge actions: update manifest.json and dataset card on PR merge."""

import io
import json
import logging
import os
from datetime import datetime, timezone

from huggingface_hub import HfApi, hf_hub_download, snapshot_download
from huggingface_hub.utils import EntryNotFoundError, LocalEntryNotFoundError

from dedup import DATASET_REPO_ID, compute_fingerprint, compute_sha256

logger = logging.getLogger(__name__)

api = HfApi()

# Repo path of the manifest that records hashes/fingerprints of data files.
MANIFEST_PATH = "manifest.json"
# Extensions of files under data/ that are tracked in the manifest.
DATA_FILE_EXTENSIONS = (".json", ".jsonl")


def _is_data_file(path: str) -> bool:
    """Return True if repo-relative *path* is a data file tracked in the manifest."""
    return path.startswith("data/") and path.endswith(DATA_FILE_EXTENSIONS)


def _list_data_dirs(api: HfApi) -> list[str]:
    """List top-level directory names under data/ (these become config/split names).

    ``list_repo_tree`` is called non-recursively on ``data``; entries without an
    ``rfilename`` attribute are treated as directories.
    """
    dirs: list[str] = []
    for entry in api.list_repo_tree(
        repo_id=DATASET_REPO_ID,
        repo_type="dataset",
        revision="main",
        path_in_repo="data",
    ):
        if not hasattr(entry, "rfilename"):  # it's a directory
            # entry.path is like "data/global-mmlu-lite"
            dirs.append(entry.path.split("/", 1)[-1])
    return sorted(dirs)


def _build_dataset_card(configs: list[str]) -> str:
    """Build a dataset card README.md with YAML frontmatter for the viewer.

    Each name in *configs* becomes one viewer config with a single ``train``
    split covering every ``*.json`` file under ``data/<config>/``.
    """
    yaml_lines: list[str] = []
    for config in configs:
        # json.dumps yields a valid YAML double-quoted scalar, so names such as
        # "yes" or "2024" are not reinterpreted as booleans/numbers.
        yaml_lines.extend(
            [
                f"  - config_name: {json.dumps(config)}",
                "    data_files:",
                "      - split: train",
                f"        path: {json.dumps(f'data/{config}/**/*.json')}",
            ]
        )
    yaml_block = "\n".join(yaml_lines)
    return f"""---
configs:
{yaml_block}
license: mit
---

# EEE Datastore

Evaluation data for the EEE project.
"""


def _load_manifest() -> dict:
    """Download and parse the current manifest.json from ``main``.

    Returns ``{"files": {}}`` only when the manifest does not exist yet. Any
    other failure (network, auth, malformed JSON) propagates, so that callers
    never overwrite a real manifest with an empty one.
    """
    try:
        manifest_path = hf_hub_download(
            repo_id=DATASET_REPO_ID,
            filename=MANIFEST_PATH,
            repo_type="dataset",
            revision="main",
        )
    except LocalEntryNotFoundError:
        # Subclass of EntryNotFoundError raised when the Hub is unreachable and
        # nothing is cached: the manifest may well exist, so do not assume empty.
        raise
    except EntryNotFoundError:
        return {"files": {}}
    with open(manifest_path, "r", encoding="utf-8") as f:
        manifest = json.load(f)
    manifest.setdefault("files", {})
    return manifest


def _manifest_entry(path: str, content: bytes, added_at: str) -> dict:
    """Build the manifest record for one data file.

    ``.json`` files get a fingerprint from ``compute_fingerprint``; other data
    files (``.jsonl``) reuse their SHA-256 digest as the fingerprint.
    """
    sha256 = compute_sha256(content)
    fingerprint = compute_fingerprint(content) if path.endswith(".json") else sha256
    return {"sha256": sha256, "fingerprint": fingerprint, "added_at": added_at}


def _upload_manifest(api: HfApi, manifest: dict, commit_message: str) -> None:
    """Serialize *manifest* (sorted keys, 2-space indent) and upload it as manifest.json."""
    manifest_bytes = json.dumps(manifest, indent=2, sort_keys=True).encode()
    api.upload_file(
        path_or_fileobj=io.BytesIO(manifest_bytes),
        path_in_repo=MANIFEST_PATH,
        repo_id=DATASET_REPO_ID,
        repo_type="dataset",
        commit_message=commit_message,
    )


def update_manifest(api: HfApi, merged_files: list[str]) -> None:
    """Download merged files from main, compute hashes, and update manifest.json.

    Args:
        api: Client used to upload the updated manifest.
        merged_files: Repo-relative paths to (re)index; existing entries for
            these paths are overwritten with a fresh ``added_at`` timestamp.

    Files that fail to download or hash are logged and skipped. When none of
    *merged_files* could be indexed, nothing is committed.

    Raises:
        Any error from reading the existing manifest other than it not existing.
    """
    manifest = _load_manifest()
    now = datetime.now(timezone.utc).isoformat()

    added = 0
    for file_path in merged_files:
        try:
            local_path = hf_hub_download(
                repo_id=DATASET_REPO_ID,
                filename=file_path,
                repo_type="dataset",
                revision="main",
            )
            with open(local_path, "rb") as f:
                content = f.read()
            manifest["files"][file_path] = _manifest_entry(file_path, content, now)
        except Exception:
            # Best effort per file: one bad file must not block the others.
            logger.exception("Failed to process %s for manifest", file_path)
            continue
        added += 1
        logger.info("Added %s to manifest", file_path)

    if not added:
        logger.warning("No files could be added to the manifest; skipping upload")
        return

    _upload_manifest(api, manifest, "Update manifest.json after PR merge")
    logger.info("Uploaded updated manifest.json (%d files)", len(manifest["files"]))


def update_dataset_card(api: HfApi) -> None:
    """Regenerate the dataset card README.md with configs for all data/ subdirs.

    Does nothing (beyond a warning) when data/ has no subdirectories.
    """
    configs = _list_data_dirs(api)
    if not configs:
        logger.warning("No data directories found, skipping dataset card update")
        return

    card_content = _build_dataset_card(configs)
    api.upload_file(
        path_or_fileobj=io.BytesIO(card_content.encode()),
        path_in_repo="README.md",
        repo_id=DATASET_REPO_ID,
        repo_type="dataset",
        commit_message="Update dataset card with configs for viewer",
    )
    logger.info("Updated dataset card with %d configs: %s", len(configs), configs)


def full_rebuild() -> dict:
    """Rebuild manifest.json from scratch and regenerate the dataset card.

    Downloads everything under data/ in one snapshot_download call, then walks
    the local copy to compute hashes/fingerprints for all data files. Every
    entry gets the current time as ``added_at``.

    Returns:
        A status dict with the number of files indexed.
    """
    logger.info("Starting full rebuild of manifest + dataset card")

    # Only data/ is needed for the manifest; skip README, manifest, etc.
    local_dir = snapshot_download(
        repo_id=DATASET_REPO_ID,
        repo_type="dataset",
        revision="main",
        allow_patterns=["data/*"],
    )
    logger.info("Downloaded dataset snapshot to %s", local_dir)

    data_root = os.path.join(local_dir, "data")
    now = datetime.now(timezone.utc).isoformat()
    manifest: dict = {"files": {}}

    if not os.path.isdir(data_root):
        logger.warning("No data/ directory found in snapshot")
    else:
        for dirpath, _, filenames in os.walk(data_root):
            for filename in filenames:
                if not filename.endswith(DATA_FILE_EXTENSIONS):
                    continue
                local_path = os.path.join(dirpath, filename)
                # Repo paths always use "/", whatever the local OS separator is.
                repo_path = os.path.relpath(local_path, local_dir).replace(os.sep, "/")
                try:
                    with open(local_path, "rb") as f:
                        content = f.read()
                    manifest["files"][repo_path] = _manifest_entry(repo_path, content, now)
                except Exception:
                    logger.exception("Failed to index %s", repo_path)
                    continue
                logger.info("Indexed %s", repo_path)

    logger.info("Indexed %d data files total", len(manifest["files"]))

    _upload_manifest(api, manifest, "Full rebuild of manifest.json")
    logger.info("Uploaded rebuilt manifest.json (%d files)", len(manifest["files"]))

    update_dataset_card(api)

    return {
        "status": "ok",
        "action": "full_rebuild",
        "files_indexed": len(manifest["files"]),
    }


def handle_merge(pr_num: int) -> dict:
    """Run all post-merge actions for a PR.

    After the merge everything is on main. Rather than diffing the PR, this
    lists every data file on main and adds the ones the manifest does not
    track yet. It then regenerates the dataset card.

    Returns:
        A status dict; ``files_added_to_manifest`` counts the untracked files
        that were submitted for indexing.
    """
    logger.info("Handling merge for PR #%d", pr_num)

    tracked = _load_manifest()["files"]

    untracked: list[str] = []
    for entry in api.list_repo_tree(
        repo_id=DATASET_REPO_ID,
        repo_type="dataset",
        revision="main",
        recursive=True,
    ):
        # Folder entries carry no ``rfilename``; only files are candidates.
        path = getattr(entry, "rfilename", None)
        if path is not None and _is_data_file(path) and path not in tracked:
            untracked.append(path)

    logger.info("Found %d untracked data files after merge", len(untracked))

    if untracked:
        update_manifest(api, untracked)

    update_dataset_card(api)

    return {
        "status": "ok",
        "action": "post_merge",
        "pr": pr_num,
        "files_added_to_manifest": len(untracked),
    }


if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )
    result = full_rebuild()
    print(json.dumps(result, indent=2))