# ObjectverseDiary / scripts / publish_hf_dataset.py
# qqyule's picture
# Deploy latest Objectverse Diary from fa09aac
# dd6cefc verified
"""Upload a curated Objectverse Diary SFT JSONL file to Hugging Face Datasets."""
from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import Any
def _load_validated_record(line: str, line_number: int) -> dict[str, object]:
    """Parse one JSONL line and check it has the expected SFT record shape.

    Args:
        line: Raw text of a single non-blank line from the dataset file.
        line_number: 1-based line number, used only in error messages.

    Returns:
        The decoded record (a JSON object).

    Raises:
        ValueError: If the line is not a JSON object, lacks a non-empty
            ``messages`` list, has no assistant message, or the last
            assistant message's ``content`` is not a JSON-encoded object
            containing both ``persona`` and ``diary`` keys.
    """
    try:
        record = json.loads(line)
    except json.JSONDecodeError as exc:
        raise ValueError(f"Invalid JSON on line {line_number}: {exc.msg}") from exc
    if not isinstance(record, dict):
        raise ValueError(f"Line {line_number} must be a JSON object.")

    messages = record.get("messages")
    if not isinstance(messages, list) or not messages:
        raise ValueError(f"Line {line_number} must include a non-empty messages list.")

    assistant_messages = [
        message
        for message in messages
        if isinstance(message, dict) and message.get("role") == "assistant"
    ]
    if not assistant_messages:
        raise ValueError(f"Line {line_number} must include an assistant message.")

    # Only the final assistant turn is inspected.
    assistant_content = assistant_messages[-1].get("content")
    if not isinstance(assistant_content, str):
        raise ValueError(f"Line {line_number} assistant content must be a string.")
    try:
        assistant_payload = json.loads(assistant_content)
    except json.JSONDecodeError as exc:
        raise ValueError(
            f"Line {line_number} assistant content is not valid JSON: {exc.msg}"
        ) from exc
    if not isinstance(assistant_payload, dict):
        raise ValueError(f"Line {line_number} assistant content must be a JSON object.")
    if "persona" not in assistant_payload or "diary" not in assistant_payload:
        raise ValueError(
            f"Line {line_number} assistant content must include persona and diary."
        )
    return record


def validate_dataset_file(dataset_file: Path) -> dict[str, object]:
    """Validate a JSONL dataset file and return a summary of its records.

    Blank lines are skipped. Every other line must be a JSON object with a
    non-empty ``messages`` list whose last assistant message carries a
    JSON-encoded object with ``persona`` and ``diary`` keys.

    Args:
        dataset_file: Path to the UTF-8 encoded JSONL file.

    Returns:
        A dict with ``dataset_file`` (the path as a string), ``record_count``,
        sorted ``sources`` and ``modes`` (distinct string values of the
        records' ``source`` / ``mode`` keys) and ``unique_object_count``
        (number of distinct ``object_understanding.object.name`` strings).

    Raises:
        FileNotFoundError: If ``dataset_file`` is missing or not a regular file.
        ValueError: If any record is malformed or the file has no records.
    """
    if not dataset_file.is_file():
        raise FileNotFoundError(f"Dataset file does not exist: {dataset_file}")

    record_count = 0
    sources: set[str] = set()
    modes: set[str] = set()
    object_names: set[str] = set()

    # Iterate the file object rather than ``read_text().splitlines()``:
    # ``str.splitlines`` also breaks on characters such as U+2028/U+2029/U+0085,
    # which may appear unescaped inside JSON strings and would cut a valid
    # record in two. File iteration only splits on real newlines and streams
    # the file instead of loading it whole.
    with dataset_file.open(encoding="utf-8") as handle:
        for line_number, line in enumerate(handle, start=1):
            if not line.strip():
                continue
            record = _load_validated_record(line, line_number)
            record_count += 1

            source = record.get("source")
            if isinstance(source, str):
                sources.add(source)
            mode = record.get("mode")
            if isinstance(mode, str):
                modes.add(mode)

            object_understanding = record.get("object_understanding")
            if isinstance(object_understanding, dict):
                raw_object = object_understanding.get("object")
                if isinstance(raw_object, dict):
                    name = raw_object.get("name")
                    if isinstance(name, str):
                        object_names.add(name)

    if record_count == 0:
        raise ValueError(f"Dataset file has no records: {dataset_file}")

    return {
        "dataset_file": str(dataset_file),
        "record_count": record_count,
        "sources": sorted(sources),
        "modes": sorted(modes),
        "unique_object_count": len(object_names),
    }
def upload_dataset(
    *,
    dataset_file: Path,
    repo_id: str,
    path_in_repo: str,
    private: bool,
    commit_message: str,
    dry_run: bool,
) -> dict[str, object]:
    """Validate the dataset file and, unless ``dry_run``, upload it to the Hub.

    Args:
        dataset_file: Local JSONL file to validate and upload.
        repo_id: Target dataset repository id.
        path_in_repo: Destination path of the file inside the repository.
        private: Passed to ``create_repo`` as the ``private`` flag.
        commit_message: Commit message used for the upload.
        dry_run: When true, only validation runs and nothing is uploaded.

    Returns:
        The validation summary extended with the call arguments, an
        ``uploaded`` flag and, after a real upload, the dataset ``url``.

    Raises:
        FileNotFoundError, ValueError: Propagated from ``validate_dataset_file``.
    """
    report = validate_dataset_file(dataset_file)
    report["repo_id"] = repo_id
    report["path_in_repo"] = path_in_repo
    report["private"] = private
    report["commit_message"] = commit_message
    report["dry_run"] = dry_run

    if dry_run:
        report["uploaded"] = False
    else:
        # The import only executes on the real-upload path, so a dry run
        # never needs huggingface_hub to be importable.
        from huggingface_hub import HfApi

        hub = HfApi()
        # exist_ok=True: reuse the repository if it already exists.
        hub.create_repo(
            repo_id=repo_id,
            repo_type="dataset",
            private=private,
            exist_ok=True,
        )
        hub.upload_file(
            path_or_fileobj=str(dataset_file),
            path_in_repo=path_in_repo,
            repo_id=repo_id,
            repo_type="dataset",
            commit_message=commit_message,
        )
        report["uploaded"] = True
        report["url"] = f"https://huggingface.co/datasets/{repo_id}"

    return report
def _print_json(payload: dict[str, Any]) -> None:
    """Print ``payload`` to stdout as 2-space-indented JSON with sorted keys."""
    rendered = json.dumps(payload, sort_keys=True, indent=2)
    # Flush immediately so the summary is emitted as soon as it is printed.
    print(rendered, flush=True)
def _parse_args() -> argparse.Namespace:
    """Parse the command-line options for the upload script.

    Returns:
        A namespace with ``dataset_file`` (a ``Path``), ``repo_id``,
        ``path_in_repo``, ``private``, ``commit_message`` and ``dry_run``.
    """
    cli = argparse.ArgumentParser(description=__doc__)

    cli.add_argument("--dataset-file", type=Path, required=True)
    # The remaining mandatory options are plain strings.
    for required_flag in ("--repo-id", "--path-in-repo"):
        cli.add_argument(required_flag, required=True)

    cli.add_argument("--private", action="store_true")
    cli.add_argument(
        "--commit-message",
        default="Upload Objectverse Diary curated SFT dataset",
    )
    cli.add_argument("--dry-run", action="store_true")

    return cli.parse_args()
def main() -> None:
    """Parse CLI arguments, run the validation/upload, and print the summary as JSON."""
    options = _parse_args()
    summary = upload_dataset(
        dataset_file=options.dataset_file,
        repo_id=options.repo_id,
        path_in_repo=options.path_in_repo,
        private=options.private,
        commit_message=options.commit_message,
        dry_run=options.dry_run,
    )
    _print_json(summary)
if __name__ == "__main__":
    # Script entry point: any ordinary exception becomes a SystemExit
    # carrying the exception's message, chained to the original error.
    try:
        main()
    except Exception as error:
        raise SystemExit(str(error)) from error