"""Post-merge actions: update manifest.json and dataset card on PR merge."""
import io
import json
import logging
import os
from datetime import datetime, timezone
from huggingface_hub import HfApi, hf_hub_download, snapshot_download
from huggingface_hub.utils import EntryNotFoundError
from dedup import DATASET_REPO_ID, compute_fingerprint, compute_sha256
logger = logging.getLogger(__name__)
api = HfApi()
def _list_data_dirs(api: HfApi) -> list[str]:
    """List top-level directory names under data/ (these become config/split names)."""
    entries = api.list_repo_tree(
        repo_id=DATASET_REPO_ID,
        repo_type="dataset",
        revision="main",
        path_in_repo="data",
    )
    # Directory entries carry no `rfilename`; their `path` looks like
    # "data/<name>", so strip the leading "data/" segment.
    names = [
        entry.path.split("/", 1)[-1]
        for entry in entries
        if not hasattr(entry, "rfilename")
    ]
    return sorted(names)
def _build_dataset_card(configs: list[str]) -> str:
"""Build a dataset card README.md with YAML frontmatter for the viewer."""
yaml_configs = []
for config in configs:
yaml_configs.append(f" - config_name: {config}")
yaml_configs.append(f" data_files:")
yaml_configs.append(f" - split: train")
yaml_configs.append(f" path: data/{config}/**/*.json")
yaml_block = "\n".join(yaml_configs)
return f"""---
configs:
{yaml_block}
license: mit
---
# EEE Datastore
Evaluation data for the EEE project.
"""
def update_manifest(api: HfApi, merged_files: list[str]) -> None:
    """Download merged files from main, compute hashes, and update manifest.json.

    For each path in *merged_files*, the file is fetched from the ``main``
    revision, its sha256 and content fingerprint are computed, and the entry
    is written (or overwritten) in manifest.json, which is then uploaded.

    Args:
        api: HfApi client used to upload the updated manifest.
        merged_files: Repo-relative file paths (e.g. "data/foo/bar.json").
    """
    # Load the existing manifest; a missing file just means we start fresh.
    try:
        manifest_path = hf_hub_download(
            repo_id=DATASET_REPO_ID,
            filename="manifest.json",
            repo_type="dataset",
            revision="main",
        )
        with open(manifest_path, "r", encoding="utf-8") as f:
            manifest = json.load(f)
    except EntryNotFoundError:
        manifest = {"files": {}}
    except Exception:
        # Best-effort: a corrupt/unreachable manifest should not block the
        # post-merge update, but log it instead of swallowing it silently.
        logger.exception("Failed to load existing manifest.json; starting fresh")
        manifest = {"files": {}}
    # Guard against a manifest that was uploaded without a "files" section.
    manifest.setdefault("files", {})

    now = datetime.now(timezone.utc).isoformat()
    for file_path in merged_files:
        try:
            local_path = hf_hub_download(
                repo_id=DATASET_REPO_ID,
                filename=file_path,
                repo_type="dataset",
                revision="main",
            )
            with open(local_path, "rb") as f:
                content = f.read()
            sha256 = compute_sha256(content)
            # .json files get a content-normalized fingerprint for dedup;
            # anything else falls back to the raw sha256.
            if file_path.endswith(".json"):
                fingerprint = compute_fingerprint(content)
            else:
                fingerprint = sha256
            manifest["files"][file_path] = {
                "sha256": sha256,
                "fingerprint": fingerprint,
                "added_at": now,
            }
            logger.info("Added %s to manifest", file_path)
        except Exception:
            # One bad file should not abort indexing of the rest.
            logger.exception("Failed to process %s for manifest", file_path)

    # Upload updated manifest
    manifest_bytes = json.dumps(manifest, indent=2, sort_keys=True).encode()
    api.upload_file(
        path_or_fileobj=io.BytesIO(manifest_bytes),
        path_in_repo="manifest.json",
        repo_id=DATASET_REPO_ID,
        repo_type="dataset",
        commit_message="Update manifest.json after PR merge",
    )
    logger.info("Uploaded updated manifest.json (%d files)", len(manifest["files"]))
def update_dataset_card(api: HfApi) -> None:
    """Regenerate the dataset card README.md with configs for all data/ subdirs."""
    config_names = _list_data_dirs(api)
    # Nothing to render — leave the existing card untouched.
    if not config_names:
        logger.warning("No data directories found, skipping dataset card update")
        return
    readme = _build_dataset_card(config_names)
    api.upload_file(
        path_or_fileobj=io.BytesIO(readme.encode()),
        path_in_repo="README.md",
        repo_id=DATASET_REPO_ID,
        repo_type="dataset",
        commit_message="Update dataset card with configs for viewer",
    )
    logger.info(
        "Updated dataset card with %d configs: %s", len(config_names), config_names
    )
def full_rebuild() -> dict:
    """Rebuild manifest.json from scratch and regenerate the dataset card.

    Downloads the entire dataset at once via snapshot_download, then walks
    the local directory to compute hashes/fingerprints for all data files.

    Returns:
        Summary dict with ``status``, ``action``, and ``files_indexed``.
    """
    logger.info("Starting full rebuild of manifest + dataset card")
    # Download entire dataset in one shot
    local_dir = snapshot_download(
        repo_id=DATASET_REPO_ID,
        repo_type="dataset",
        revision="main",
    )
    logger.info("Downloaded dataset snapshot to %s", local_dir)

    # Walk local data/ directory to find all data files
    data_root = os.path.join(local_dir, "data")
    now = datetime.now(timezone.utc).isoformat()
    manifest: dict = {"files": {}}
    if not os.path.isdir(data_root):
        logger.warning("No data/ directory found in snapshot")
    else:
        for dirpath, _, filenames in os.walk(data_root):
            for filename in filenames:
                if not filename.endswith((".json", ".jsonl")):
                    continue
                local_path = os.path.join(dirpath, filename)
                # Convert to a repo-relative path ("data/...") and force
                # forward slashes: os.path.relpath uses os.sep, which would
                # produce backslash keys on Windows that never match repo paths.
                repo_path = os.path.relpath(local_path, local_dir).replace(os.sep, "/")
                try:
                    with open(local_path, "rb") as f:
                        content = f.read()
                    sha256 = compute_sha256(content)
                    # .json files get a normalized fingerprint; .jsonl falls
                    # back to the raw sha256 (mirrors update_manifest).
                    if filename.endswith(".json"):
                        fingerprint = compute_fingerprint(content)
                    else:
                        fingerprint = sha256
                    manifest["files"][repo_path] = {
                        "sha256": sha256,
                        "fingerprint": fingerprint,
                        "added_at": now,
                    }
                    logger.info("Indexed %s", repo_path)
                except Exception:
                    # Skip unreadable files but keep indexing the rest.
                    logger.exception("Failed to index %s", repo_path)
    logger.info("Indexed %d data files total", len(manifest["files"]))

    # Upload fresh manifest
    manifest_bytes = json.dumps(manifest, indent=2, sort_keys=True).encode()
    api.upload_file(
        path_or_fileobj=io.BytesIO(manifest_bytes),
        path_in_repo="manifest.json",
        repo_id=DATASET_REPO_ID,
        repo_type="dataset",
        commit_message="Full rebuild of manifest.json",
    )
    logger.info("Uploaded rebuilt manifest.json (%d files)", len(manifest["files"]))

    # Regenerate dataset card
    update_dataset_card(api)
    return {
        "status": "ok",
        "action": "full_rebuild",
        "files_indexed": len(manifest["files"]),
    }
def handle_merge(pr_num: int) -> dict:
    """Run all post-merge actions for a PR.

    After the merge everything is already on main, so rather than diffing the
    PR we list all data files on main and index any not yet in the manifest.

    Args:
        pr_num: PR number, used only for logging and the result payload.

    Returns:
        Summary dict with ``status``, ``action``, ``pr``, and
        ``files_added_to_manifest``.
    """
    logger.info("Handling merge for PR #%d", pr_num)
    # Load the current manifest; a missing one means nothing is tracked yet.
    try:
        manifest_path = hf_hub_download(
            repo_id=DATASET_REPO_ID,
            filename="manifest.json",
            repo_type="dataset",
            revision="main",
        )
        with open(manifest_path, "r", encoding="utf-8") as f:
            manifest = json.load(f)
    except EntryNotFoundError:
        manifest = {"files": {}}
    except Exception:
        # Best-effort: log real errors instead of silently treating them
        # as "no manifest yet".
        logger.exception("Failed to load manifest.json; assuming empty")
        manifest = {"files": {}}

    # List all data files on main and find untracked ones
    all_data_files: list[str] = []
    for entry in api.list_repo_tree(
        repo_id=DATASET_REPO_ID,
        repo_type="dataset",
        revision="main",
        recursive=True,
    ):
        if not hasattr(entry, "rfilename"):
            continue  # directories carry no rfilename
        path = entry.rfilename
        if path.startswith("data/") and path.endswith((".json", ".jsonl")):
            all_data_files.append(path)

    # Hoist the dict lookup out of the comprehension.
    tracked = manifest.get("files", {})
    untracked = [f for f in all_data_files if f not in tracked]
    logger.info("Found %d untracked data files after merge", len(untracked))
    if untracked:
        update_manifest(api, untracked)
    update_dataset_card(api)
    return {
        "status": "ok",
        "action": "post_merge",
        "pr": pr_num,
        "files_added_to_manifest": len(untracked),
    }
# Script entry point: running this module directly performs a full
# manifest + dataset-card rebuild and prints the summary as JSON.
if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )
    result = full_rebuild()
    print(json.dumps(result, indent=2))