ObjectverseDiary / scripts /generate_dataset.py
qqyule's picture
feat: add initial mock mvp
bc02199
"""Generate deterministic SFT preview data for Objectverse Diary."""
from __future__ import annotations
import argparse
import json
import sys
from collections.abc import Mapping, Sequence
from pathlib import Path
# Repository root: two levels above this file (i.e. the parent of the directory
# holding this script). It is pushed to the front of ``sys.path`` so that the
# ``src`` package imports resolve when the file is executed directly.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
_project_root_entry = str(PROJECT_ROOT)
if _project_root_entry not in sys.path:
    sys.path.insert(0, _project_root_entry)
from src.config import PERSONALITY_MODES
from src.examples import EXAMPLE_OBJECTS
from src.pipeline import generate_object_diary
# Default JSONL destination. Anchored at PROJECT_ROOT (rather than the current
# working directory) so running the script from any directory writes into the
# repository's data/train/ folder; pass --output to override.
DEFAULT_OUTPUT_PATH = PROJECT_ROOT / "data" / "train" / "objectverse_sft_preview.jsonl"
# Number of records generated when --count is not given.
DEFAULT_COUNT = 60
# System message placed first in every generated conversation.
SYSTEM_PROMPT = (
    "You are Objectverse Diary, an English-first small-model assistant. "
    "Given structured object understanding and a requested personality mode, "
    "return JSON containing a consistent hidden object persona and a short diary."
)
# Scene phrases appended to example descriptions, selected round-robin by record
# index, to vary the inputs produced for repeated examples.
SCENE_VARIANTS = [
    "after a long day of being ignored",
    "beside half-finished notes",
    "under warm desk light",
    "near a quiet morning window",
    "during a suspiciously ordinary afternoon",
    "while humans rush past without noticing",
    "after surviving another minor household crisis",
    "beside a charging cable and old receipts",
    "in the corner of a room that knows too much",
    "before anyone remembers to clean it",
]
def build_sft_records(count: int = DEFAULT_COUNT) -> list[dict[str, object]]:
    """Build ``count`` deterministic chat-format SFT records.

    Record ``i`` uses ``EXAMPLE_OBJECTS[i % len(EXAMPLE_OBJECTS)]``. Its mode
    comes from ``_mode_for_index`` and its scene-suffixed description from
    ``_description_for_index``. The diary pipeline is run with no image and
    ``save=False``. The output is packaged as a system/user/assistant message
    list, stored alongside the raw inputs.

    Args:
        count: Number of records to build; must be at least 1.

    Returns:
        Record dicts in index order, with ids ``sft-preview-0001`` onward.

    Raises:
        ValueError: If ``count`` is less than 1 or ``EXAMPLE_OBJECTS`` is empty.
    """
    if count < 1:
        raise ValueError("count must be at least 1")
    if not EXAMPLE_OBJECTS:
        # Without this guard the modulo below fails with an opaque
        # ZeroDivisionError.
        raise ValueError("EXAMPLE_OBJECTS must not be empty")

    records: list[dict[str, object]] = []
    for index in range(count):
        # 1-based, zero-padded id; also reused as the pipeline trace id.
        record_id = f"sft-preview-{index + 1:04d}"
        example = EXAMPLE_OBJECTS[index % len(EXAMPLE_OBJECTS)]
        mode = _mode_for_index(example, index)
        description = _description_for_index(example, index)
        result = generate_object_diary(
            image_path=None,
            description=description,
            mode=mode,
            save=False,
            trace_id=record_id,
        )
        # Dump once and reuse: the same dict is stored on the record and
        # embedded (as sorted-key JSON) in the user prompt.
        object_understanding = result.object_understanding.model_dump(mode="json")
        # Assistant target: persona then diary, serialized in insertion order.
        assistant_payload = {
            "persona": result.persona.persona.model_dump(mode="json"),
            "diary": result.diary.model_dump(mode="json"),
        }
        records.append(
            {
                "id": record_id,
                "source": "objectverse-diary-mock-mvp",
                "split": "preview",
                "mode": mode,
                "object_description": description,
                "object_understanding": object_understanding,
                "messages": [
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {
                        "role": "user",
                        "content": _user_prompt(object_understanding, mode),
                    },
                    {
                        "role": "assistant",
                        "content": json.dumps(assistant_payload, ensure_ascii=False),
                    },
                ],
            }
        )
    return records
def write_sft_jsonl(records: Sequence[Mapping[str, object]], output_path: Path) -> Path:
    """Serialize ``records`` to ``output_path`` as JSON Lines.

    Format:
        - Each record becomes one line of JSON, with sorted keys and non-ASCII
          characters kept verbatim.
        - The file is UTF-8 encoded.
        - Every line, including the last, ends with a newline.
        - An empty ``records`` yields an empty file.

    Missing parent directories are created, and an existing file at
    ``output_path`` is overwritten.

    Args:
        records: JSON-serializable mappings to write, in order.
        output_path: Destination file path.

    Returns:
        ``output_path``, unchanged.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Terminate each line individually: joining with "\n" and appending a final
    # "\n" would emit a lone blank (invalid) JSONL line for empty input.
    body = "".join(
        json.dumps(record, ensure_ascii=False, sort_keys=True) + "\n" for record in records
    )
    output_path.write_text(body, encoding="utf-8")
    return output_path
def generate_dataset(output_path: Path = DEFAULT_OUTPUT_PATH, count: int = DEFAULT_COUNT) -> Path:
    """Build ``count`` SFT records and write them as JSONL to ``output_path``.

    Returns:
        The path returned by ``write_sft_jsonl`` (the file that was written).

    Raises:
        ValueError: Propagated from ``build_sft_records`` when ``count`` < 1.
    """
    records = build_sft_records(count)
    written_path = write_sft_jsonl(records, output_path)
    return written_path
def _mode_for_index(example: Mapping[str, str], index: int) -> str:
    """Pick the personality mode for record ``index``.

    Rotates through ``PERSONALITY_MODES``. The rotation starts at the example's
    own ``"mode"`` when that mode is listed, and at the first mode otherwise.
    """
    base_mode = example["mode"]
    modes = PERSONALITY_MODES
    try:
        offset = modes.index(base_mode)
    except ValueError:
        # Unknown mode on the example: fall back to starting at the first mode.
        offset = 0
    return modes[(offset + index) % len(modes)]
def _description_for_index(example: Mapping[str, str], index: int) -> str:
    """Return the example's description with a round-robin scene appended.

    The scene is ``SCENE_VARIANTS[index % len(SCENE_VARIANTS)]``, joined to the
    description with ", ".
    """
    scene_slot = index % len(SCENE_VARIANTS)
    base_description = example["description"]
    return f"{base_description}, {SCENE_VARIANTS[scene_slot]}"
def _user_prompt(object_understanding: Mapping[str, object], mode: str) -> str:
payload = json.dumps(object_understanding, ensure_ascii=False, sort_keys=True)
return (
f"Personality mode: {mode}\n"
f"Object understanding JSON: {payload}\n"
"Return JSON with keys persona and diary."
)
def _parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--count", type=int, default=DEFAULT_COUNT)
parser.add_argument("--output", type=Path, default=DEFAULT_OUTPUT_PATH)
return parser.parse_args()
def main() -> None:
    """Command-line entry point: generate the preview dataset and report the path."""
    options = _parse_args()
    destination = generate_dataset(output_path=options.output, count=options.count)
    print(f"wrote {options.count} SFT preview records to {destination}")
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()