# Spaces: Running on Zero | File size: 4,717 Bytes | bc02199
"""Generate deterministic SFT preview data for Objectverse Diary."""
from __future__ import annotations
import argparse
import json
import sys
from collections.abc import Mapping, Sequence
from pathlib import Path
# parents[1] of the resolved file path is the directory that contains this
# script's own directory. It is put at the front of sys.path (only if not
# already present) so the `src` package imported just below can be found when
# this file is run directly as a script.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
from src.config import PERSONALITY_MODES
from src.examples import EXAMPLE_OBJECTS
from src.pipeline import generate_object_diary
# Default JSONL destination for generate_dataset() and the --output flag.
# The path is relative, so it is resolved against the current working directory.
DEFAULT_OUTPUT_PATH = Path("data/train/objectverse_sft_preview.jsonl")
# Default number of records for build_sft_records(), generate_dataset() and
# the --count flag.
DEFAULT_COUNT = 60
# Content of the "system" message that build_sft_records() places first in
# every record's "messages" list. The text ends up in the written dataset, so
# any edit here changes the generated output.
SYSTEM_PROMPT = (
    "You are Objectverse Diary, an English-first small-model assistant. "
    "Given structured object understanding and a requested personality mode, "
    "return JSON containing a consistent hidden object persona and a short diary."
)
# Scene phrases that _description_for_index() appends to an example's
# description. They are selected by `index % len(SCENE_VARIANTS)`, so the list
# order determines which record gets which scene.
SCENE_VARIANTS = [
    "after a long day of being ignored",
    "beside half-finished notes",
    "under warm desk light",
    "near a quiet morning window",
    "during a suspiciously ordinary afternoon",
    "while humans rush past without noticing",
    "after surviving another minor household crisis",
    "beside a charging cable and old receipts",
    "in the corner of a room that knows too much",
    "before anyone remembers to clean it",
]
def build_sft_records(count: int = DEFAULT_COUNT) -> list[dict[str, object]]:
    """Build ``count`` chat-style SFT preview records.

    Record ``i`` (0-based) is derived from
    ``EXAMPLE_OBJECTS[i % len(EXAMPLE_OBJECTS)]``, with its mode chosen by
    ``_mode_for_index`` and its description by ``_description_for_index``.
    Each record is produced by one ``generate_object_diary`` call made with
    ``image_path=None`` and ``save=False``.

    Args:
        count: Number of records to build; must be at least 1.

    Returns:
        A list of ``count`` dicts. Each holds the keys ``id``, ``source``,
        ``split``, ``mode``, ``object_description``, ``object_understanding``
        and ``messages`` (system, user and assistant messages, in that order).

    Raises:
        ValueError: If ``count`` is less than 1.
    """
    if count < 1:
        raise ValueError("count must be at least 1")
    return [_build_sft_record(index) for index in range(count)]


def _build_sft_record(index: int) -> dict[str, object]:
    """Assemble the single preview record for zero-based position ``index``."""
    example = EXAMPLE_OBJECTS[index % len(EXAMPLE_OBJECTS)]
    mode = _mode_for_index(example, index)
    description = _description_for_index(example, index)

    # One identifier is used both as the pipeline trace id and as the record
    # id, so the two values cannot drift apart.
    record_id = f"sft-preview-{index + 1:04d}"
    result = generate_object_diary(
        image_path=None,
        description=description,
        mode=mode,
        save=False,
        trace_id=record_id,
    )

    # Dump the object understanding once and reuse it for both the record
    # field and the user prompt (the prompt only serializes it to a string).
    understanding = result.object_understanding.model_dump(mode="json")
    assistant_payload = {
        "persona": result.persona.persona.model_dump(mode="json"),
        "diary": result.diary.model_dump(mode="json"),
    }
    return {
        "id": record_id,
        "source": "objectverse-diary-mock-mvp",
        "split": "preview",
        "mode": mode,
        "object_description": description,
        "object_understanding": understanding,
        "messages": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": _user_prompt(understanding, mode)},
            {
                "role": "assistant",
                "content": json.dumps(assistant_payload, ensure_ascii=False),
            },
        ],
    }
def write_sft_jsonl(records: Sequence[Mapping[str, object]], output_path: Path) -> Path:
    """Write ``records`` to ``output_path`` as UTF-8 JSON Lines.

    Every record becomes one line of JSON with sorted keys and non-ASCII
    characters kept as-is, terminated by a single ``"\\n"``. Missing parent
    directories are created. An empty ``records`` sequence produces an empty
    file rather than a file holding one blank line, which JSONL readers
    would try (and fail) to parse as a record.

    Args:
        records: JSON-serializable mappings, written in the given order.
        output_path: Destination file; overwritten if it already exists.

    Returns:
        ``output_path``, unchanged.

    Raises:
        TypeError: If a record contains a value ``json.dumps`` cannot encode.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Serialize everything before opening the file so that an unencodable
    # record fails before the destination is touched.
    lines = [json.dumps(record, ensure_ascii=False, sort_keys=True) for record in records]
    # newline="\n" disables platform newline translation, keeping the written
    # bytes identical on every operating system.
    with output_path.open("w", encoding="utf-8", newline="\n") as handle:
        handle.writelines(f"{line}\n" for line in lines)
    return output_path
def generate_dataset(output_path: Path = DEFAULT_OUTPUT_PATH, count: int = DEFAULT_COUNT) -> Path:
    """Build ``count`` preview records and write them to ``output_path``.

    Returns whatever ``write_sft_jsonl`` returns for the written file.
    """
    preview_records = build_sft_records(count)
    written_path = write_sft_jsonl(preview_records, output_path)
    return written_path
def _mode_for_index(example: Mapping[str, str], index: int) -> str:
    """Choose the personality mode for the record at ``index``.

    The starting point is the position of the example's own ``"mode"`` in
    ``PERSONALITY_MODES`` (position 0 when that mode is not listed). The
    result is the mode ``index`` steps further on, wrapping around the end.
    """
    own_mode = example["mode"]
    if own_mode in PERSONALITY_MODES:
        start = PERSONALITY_MODES.index(own_mode)
    else:
        start = 0
    position = (start + index) % len(PERSONALITY_MODES)
    return PERSONALITY_MODES[position]
def _description_for_index(example: Mapping[str, str], index: int) -> str:
    """Return the example's description followed by a scene phrase.

    The scene is ``SCENE_VARIANTS[index % len(SCENE_VARIANTS)]`` and is joined
    to the description with a comma and a space.
    """
    scene_text = SCENE_VARIANTS[index % len(SCENE_VARIANTS)]
    base_text = example["description"]
    return "{}, {}".format(base_text, scene_text)
def _user_prompt(object_understanding: Mapping[str, object], mode: str) -> str:
payload = json.dumps(object_understanding, ensure_ascii=False, sort_keys=True)
return (
f"Personality mode: {mode}\n"
f"Object understanding JSON: {payload}\n"
"Return JSON with keys persona and diary."
)
def _parse_args() -> argparse.Namespace:
    """Parse the ``--count`` and ``--output`` command-line options.

    ``--count`` must be an integer of at least 1. Other values are rejected
    by argparse with a usage error, instead of being accepted here and
    failing later with a ``ValueError`` traceback from ``build_sft_records``.

    Returns:
        Namespace with ``count`` (int, default ``DEFAULT_COUNT``) and
        ``output`` (Path, default ``DEFAULT_OUTPUT_PATH``).
    """

    def record_count(raw: str) -> int:
        """Convert a ``--count`` value, rejecting non-integers and values below 1."""
        try:
            value = int(raw)
        except ValueError:
            # Keep argparse's usual wording for a non-integer argument.
            raise argparse.ArgumentTypeError(f"invalid int value: {raw!r}") from None
        if value < 1:
            raise argparse.ArgumentTypeError("count must be at least 1")
        return value

    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--count", type=record_count, default=DEFAULT_COUNT)
    parser.add_argument("--output", type=Path, default=DEFAULT_OUTPUT_PATH)
    return parser.parse_args()
def main() -> None:
    """Command-line entry point.

    Parses the CLI options, generates the dataset at the requested path with
    the requested record count, and prints a one-line summary.
    """
    cli_args = _parse_args()
    written_to = generate_dataset(output_path=cli_args.output, count=cli_args.count)
    print(f"wrote {cli_args.count} SFT preview records to {written_to}")


# Run the generator only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|