File size: 4,717 Bytes
bc02199
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
"""Generate deterministic SFT preview data for Objectverse Diary."""

from __future__ import annotations

import argparse
import json
import sys
from collections.abc import Mapping, Sequence
from pathlib import Path

# Repository root: the parent of the directory containing this file
# (``parents[1]`` of the resolved file path).
# Prepend it to ``sys.path`` (once) so the ``src`` package imported just below
# resolves even when this file is executed directly rather than as part of an
# installed package. This must run before the ``src.*`` imports.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from src.config import PERSONALITY_MODES
from src.examples import EXAMPLE_OBJECTS
from src.pipeline import generate_object_diary


# Default JSONL destination. NOTE(review): this is a relative path, so it is
# resolved against the current working directory, not PROJECT_ROOT; running
# the script from another directory writes the file there. Confirm intended.
DEFAULT_OUTPUT_PATH = Path("data/train/objectverse_sft_preview.jsonl")
# Default number of records generated when no count is given.
DEFAULT_COUNT = 60

# System message placed first in every record's ``messages`` list.
SYSTEM_PROMPT = (
    "You are Objectverse Diary, an English-first small-model assistant. "
    "Given structured object understanding and a requested personality mode, "
    "return JSON containing a consistent hidden object persona and a short diary."
)

# Scene phrases appended to example descriptions; selected by
# ``index % len(SCENE_VARIANTS)``, so records cycle through them in order.
# Editing, adding or reordering entries changes the generated dataset.
SCENE_VARIANTS = [
    "after a long day of being ignored",
    "beside half-finished notes",
    "under warm desk light",
    "near a quiet morning window",
    "during a suspiciously ordinary afternoon",
    "while humans rush past without noticing",
    "after surviving another minor household crisis",
    "beside a charging cable and old receipts",
    "in the corner of a room that knows too much",
    "before anyone remembers to clean it",
]


def build_sft_records(count: int = DEFAULT_COUNT) -> list[dict[str, object]]:
    """Build ``count`` chat-format SFT records from the bundled example objects.

    Records cycle through ``EXAMPLE_OBJECTS``. For each index the personality
    mode comes from ``_mode_for_index`` and the description from
    ``_description_for_index``. Each pair is run through
    ``generate_object_diary`` with no image and ``save=False``, and the result
    is packaged as a system/user/assistant message triple. Record ids (also
    used as pipeline trace ids) are ``sft-preview-0001``,
    ``sft-preview-0002``, and so on.

    Args:
        count: Number of records to produce; must be at least 1.

    Returns:
        A list of record dicts in index order. Each has the keys ``id``,
        ``source``, ``split``, ``mode``, ``object_description``,
        ``object_understanding`` and ``messages``.

    Raises:
        ValueError: If ``count`` is less than 1.
    """
    if count < 1:
        raise ValueError("count must be at least 1")

    records: list[dict[str, object]] = []
    for index in range(count):
        # A single id serves as both the pipeline trace id and the record id,
        # so the two can never drift apart.
        record_id = f"sft-preview-{index + 1:04d}"
        example = EXAMPLE_OBJECTS[index % len(EXAMPLE_OBJECTS)]
        mode = _mode_for_index(example, index)
        description = _description_for_index(example, index)
        result = generate_object_diary(
            image_path=None,
            description=description,
            mode=mode,
            save=False,
            trace_id=record_id,
        )
        # Dump the object understanding once and reuse it, so the stored field
        # and the user prompt are guaranteed to describe the same data.
        understanding = result.object_understanding.model_dump(mode="json")
        assistant_payload = {
            "persona": result.persona.persona.model_dump(mode="json"),
            "diary": result.diary.model_dump(mode="json"),
        }
        records.append(
            {
                "id": record_id,
                "source": "objectverse-diary-mock-mvp",
                "split": "preview",
                "mode": mode,
                "object_description": description,
                "object_understanding": understanding,
                "messages": [
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": _user_prompt(understanding, mode)},
                    {
                        "role": "assistant",
                        "content": json.dumps(assistant_payload, ensure_ascii=False),
                    },
                ],
            }
        )

    return records


def write_sft_jsonl(records: Sequence[Mapping[str, object]], output_path: Path | str) -> Path:
    """Write ``records`` to ``output_path`` as JSON Lines and return the path.

    Each record is serialized with sorted keys and ``ensure_ascii=False`` and
    terminated by ``"\\n"``. Line endings are forced to ``"\\n"`` on every
    platform, so the output is byte-identical everywhere. An empty ``records``
    produces an empty file rather than a stray blank line. Missing parent
    directories are created.

    Args:
        records: JSON-serializable mappings, one per output line.
        output_path: Destination file (``Path`` or ``str``); overwritten if present.

    Returns:
        The destination as a ``Path``.

    Raises:
        TypeError: If a record is not JSON-serializable. Serialization happens
            before the file is opened, so an existing file is left untouched.
    """
    output_path = Path(output_path)
    # Serialize everything up front so a bad record cannot leave a truncated file.
    lines = [json.dumps(record, ensure_ascii=False, sort_keys=True) + "\n" for record in records]
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # newline="\n" disables platform newline translation (e.g. CRLF on Windows).
    with output_path.open("w", encoding="utf-8", newline="\n") as handle:
        handle.writelines(lines)
    return output_path


def generate_dataset(output_path: Path = DEFAULT_OUTPUT_PATH, count: int = DEFAULT_COUNT) -> Path:
    """Build ``count`` preview records and write them as JSONL to ``output_path``.

    Returns whatever ``write_sft_jsonl`` returns for the written file.
    """
    preview_records = build_sft_records(count)
    return write_sft_jsonl(preview_records, output_path)


def _mode_for_index(example: Mapping[str, str], index: int) -> str:
    """Return the personality mode for record ``index``.

    Rotates through ``PERSONALITY_MODES`` starting at the example's own
    ``"mode"``. When that mode is not in ``PERSONALITY_MODES``, rotation
    starts at the first entry instead.
    """
    modes = PERSONALITY_MODES
    preferred = example["mode"]
    try:
        start = modes.index(preferred)
    except ValueError:
        start = 0
    offset = (start + index) % len(modes)
    return modes[offset]


def _description_for_index(example: Mapping[str, str], index: int) -> str:
    """Return the example's description followed by a rotating scene phrase.

    The phrase is ``SCENE_VARIANTS[index % len(SCENE_VARIANTS)]``, joined to
    the base description with ``", "``.
    """
    variant_count = len(SCENE_VARIANTS)
    scene_phrase = SCENE_VARIANTS[index % variant_count]
    return "{}, {}".format(example["description"], scene_phrase)


def _user_prompt(object_understanding: Mapping[str, object], mode: str) -> str:
    payload = json.dumps(object_understanding, ensure_ascii=False, sort_keys=True)
    return (
        f"Personality mode: {mode}\n"
        f"Object understanding JSON: {payload}\n"
        "Return JSON with keys persona and diary."
    )


def _parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for the preview-data generator.

    Args:
        argv: Arguments to parse; ``None`` (the default) lets argparse read
            ``sys.argv[1:]``.

    Returns:
        A namespace with ``count`` (int, at least 1) and ``output`` (Path).

    Exits:
        Via ``parser.error`` (usage message, exit status 2) when ``--count``
        is less than 1.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--count",
        type=int,
        default=DEFAULT_COUNT,
        help="number of records to generate (default: %(default)s)",
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=DEFAULT_OUTPUT_PATH,
        help="destination JSONL file (default: %(default)s)",
    )
    args = parser.parse_args(argv)
    # Reject bad counts here so the user gets a usage error rather than a
    # ValueError traceback from build_sft_records.
    if args.count < 1:
        parser.error("--count must be at least 1")
    return args


def main() -> None:
    """Command-line entry point: parse options, write the dataset, report it."""
    options = _parse_args()
    written_path = generate_dataset(output_path=options.output, count=options.count)
    print(f"wrote {options.count} SFT preview records to {written_path}")


if __name__ == "__main__":
    main()