File size: 5,307 Bytes
dd6cefc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
"""Upload a curated Objectverse Diary SFT JSONL file to Hugging Face Datasets."""

from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Any


def _parse_record(line: str, line_number: int) -> dict[str, Any]:
    """Decode one JSONL line and check that it is a chat-format SFT record.

    The line must decode to a JSON object with a non-empty ``messages`` list.
    The *last* message whose ``role`` is ``"assistant"`` must have string
    ``content`` that itself decodes to a JSON object containing both a
    ``persona`` and a ``diary`` key.

    Args:
        line: Raw text of the line (a trailing newline is tolerated).
        line_number: 1-based physical line number, used in error messages.

    Returns:
        The decoded record.

    Raises:
        ValueError: Naming ``line_number`` and the first check that failed.
    """
    try:
        record = json.loads(line)
    except json.JSONDecodeError as exc:
        raise ValueError(f"Invalid JSON on line {line_number}: {exc.msg}") from exc
    if not isinstance(record, dict):
        raise ValueError(f"Line {line_number} must be a JSON object.")

    messages = record.get("messages")
    if not isinstance(messages, list) or not messages:
        raise ValueError(f"Line {line_number} must include a non-empty messages list.")
    assistant_messages = [
        message
        for message in messages
        if isinstance(message, dict) and message.get("role") == "assistant"
    ]
    if not assistant_messages:
        raise ValueError(f"Line {line_number} must include an assistant message.")

    # Only the final assistant turn is inspected; earlier ones are not checked.
    assistant_content = assistant_messages[-1].get("content")
    if not isinstance(assistant_content, str):
        raise ValueError(f"Line {line_number} assistant content must be a string.")
    try:
        assistant_payload = json.loads(assistant_content)
    except json.JSONDecodeError as exc:
        raise ValueError(
            f"Line {line_number} assistant content is not valid JSON: {exc.msg}"
        ) from exc
    if not isinstance(assistant_payload, dict):
        raise ValueError(f"Line {line_number} assistant content must be a JSON object.")
    if "persona" not in assistant_payload or "diary" not in assistant_payload:
        raise ValueError(
            f"Line {line_number} assistant content must include persona and diary."
        )
    return record


def validate_dataset_file(dataset_file: Path) -> dict[str, object]:
    """Validate a chat-format SFT JSONL file and return a summary of its contents.

    Blank (whitespace-only) lines are skipped. Every other line must pass
    ``_parse_record``.

    Args:
        dataset_file: Path to the UTF-8 encoded JSONL file.

    Returns:
        A dict with:
          * ``dataset_file``: the path as a string;
          * ``record_count``: number of non-blank records;
          * ``sources`` / ``modes``: sorted distinct string values of each
            record's top-level ``source`` / ``mode`` key;
          * ``unique_object_count``: number of distinct string values found at
            ``record["object_understanding"]["object"]["name"]``.

    Raises:
        FileNotFoundError: If ``dataset_file`` is not an existing regular file.
        ValueError: If any line fails validation or the file has no records.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # is_file() is False for missing paths too, so no separate exists() check.
    if not dataset_file.is_file():
        raise FileNotFoundError(f"Dataset file does not exist: {dataset_file}")

    record_count = 0
    sources: set[str] = set()
    modes: set[str] = set()
    object_names: set[str] = set()

    # Iterate the file object instead of using str.splitlines(). splitlines()
    # also breaks on U+2028, U+2029 and U+0085, which may appear unescaped
    # inside JSON strings (e.g. json.dumps(..., ensure_ascii=False)). Splitting
    # there would corrupt valid records. Text-mode iteration splits only on
    # \n, \r and \r\n, and it streams rather than loading the whole file.
    with dataset_file.open(encoding="utf-8") as handle:
        for line_number, line in enumerate(handle, start=1):
            if not line.strip():
                continue
            record = _parse_record(line, line_number)
            record_count += 1

            source = record.get("source")
            if isinstance(source, str):
                sources.add(source)
            mode = record.get("mode")
            if isinstance(mode, str):
                modes.add(mode)
            object_understanding = record.get("object_understanding")
            if isinstance(object_understanding, dict):
                raw_object = object_understanding.get("object")
                name = raw_object.get("name") if isinstance(raw_object, dict) else None
                if isinstance(name, str):
                    object_names.add(name)

    if record_count == 0:
        raise ValueError(f"Dataset file has no records: {dataset_file}")

    return {
        "dataset_file": str(dataset_file),
        "record_count": record_count,
        "sources": sorted(sources),
        "modes": sorted(modes),
        "unique_object_count": len(object_names),
    }


def upload_dataset(
    *,
    dataset_file: Path,
    repo_id: str,
    path_in_repo: str,
    private: bool,
    commit_message: str,
    dry_run: bool,
) -> dict[str, object]:
    """Validate ``dataset_file`` and, unless ``dry_run``, push it to a HF dataset repo.

    Validation always runs first, so an invalid file is never uploaded.

    Args:
        dataset_file: Local JSONL file to validate and upload.
        repo_id: Target dataset repository id.
        path_in_repo: Destination path of the file inside the repository.
        private: Passed to ``create_repo``; the repo is created with
            ``exist_ok=True``, so an existing repo is left as it is.
        commit_message: Commit message for the upload.
        dry_run: When true, only validate and report; nothing is sent.

    Returns:
        The validation summary extended with the call arguments, an
        ``uploaded`` flag and, after a real upload, a ``url`` key.
    """
    report = validate_dataset_file(dataset_file)
    report["repo_id"] = repo_id
    report["path_in_repo"] = path_in_repo
    report["private"] = private
    report["commit_message"] = commit_message
    report["dry_run"] = dry_run

    if not dry_run:
        # Deferred import: the dry-run path never touches huggingface_hub.
        from huggingface_hub import HfApi

        hub = HfApi()
        hub.create_repo(
            repo_id=repo_id,
            repo_type="dataset",
            private=private,
            exist_ok=True,
        )
        hub.upload_file(
            path_or_fileobj=str(dataset_file),
            path_in_repo=path_in_repo,
            repo_id=repo_id,
            repo_type="dataset",
            commit_message=commit_message,
        )
        report["uploaded"] = True
        report["url"] = f"https://huggingface.co/datasets/{repo_id}"
    else:
        report["uploaded"] = False
    return report


def _print_json(payload: dict[str, Any]) -> None:
    print(json.dumps(payload, indent=2, sort_keys=True), flush=True)


def _parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--dataset-file", type=Path, required=True)
    parser.add_argument("--repo-id", required=True)
    parser.add_argument("--path-in-repo", required=True)
    parser.add_argument("--private", action="store_true")
    parser.add_argument(
        "--commit-message",
        default="Upload Objectverse Diary curated SFT dataset",
    )
    parser.add_argument("--dry-run", action="store_true")
    return parser.parse_args()


def main() -> None:
    """CLI entry point: parse flags, validate/upload the dataset, print the summary."""
    options = _parse_args()
    result = upload_dataset(
        dataset_file=options.dataset_file,
        repo_id=options.repo_id,
        path_in_repo=options.path_in_repo,
        private=options.private,
        commit_message=options.commit_message,
        dry_run=options.dry_run,
    )
    _print_json(result)


if __name__ == "__main__":
    # Report any failure as a one-line message on stderr (exit status 1)
    # instead of a full traceback.
    try:
        main()
    except Exception as exc:
        # Some exceptions carry no message (str(exc) == ""), which would print
        # an empty line; fall back to the class name so the cause is visible.
        raise SystemExit(str(exc) or type(exc).__name__) from exc