|
|
import json |
|
|
import random |
|
|
from pathlib import Path |
|
|
from utils_prompts import SYSTEM_PROMPT, format_user_prompt |
|
|
|
|
|
# Seed the global RNG at import time so the generated dataset is repeatable
# (for a given Python version) across runs of this script.
random.seed(7)


def _scenario(user_intent, cart, *, budget_usd, shipping_urgency, good_recos, brand_avoid=()):
    """Return one scenario dict in the layout the answer generators read.

    Key order is fixed as user_intent, cart, constraints, good_recos, and
    constraints holds budget_usd, shipping_urgency, brand_avoid. Every list is
    a fresh object, so scenarios never share mutable state.
    """
    constraints = {
        "budget_usd": budget_usd,
        "shipping_urgency": shipping_urgency,
        "brand_avoid": list(brand_avoid),
    }
    return {
        "user_intent": user_intent,
        "cart": list(cart),
        "constraints": constraints,
        "good_recos": list(good_recos),
    }


# Hand-written seed scenarios. Each one pairs a user intent and current cart
# with shopping constraints and a short list of on-topic recommendations.
SCENARIOS = [
    _scenario(
        "I run 5k daily and want to reduce knee discomfort.",
        ["running shoes", "moisture-wicking socks"],
        budget_usd=45,
        shipping_urgency="fast",
        good_recos=["knee compression sleeve", "foam roller", "anti-chafe balm"],
    ),
    _scenario(
        "I’m setting up pour-over coffee at home and want consistent taste.",
        ["coffee beans", "paper filters"],
        budget_usd=60,
        shipping_urgency="normal",
        good_recos=["gooseneck kettle", "digital scale", "hand grinder"],
    ),
    _scenario(
        "I get acne sometimes; want a simple skincare routine.",
        ["gentle cleanser"],
        budget_usd=35,
        shipping_urgency="normal",
        brand_avoid=["fragrance-heavy"],
        good_recos=["non-comedogenic moisturizer", "sunscreen SPF 30+", "salicylic acid spot treatment"],
    ),
    _scenario(
        "I travel weekly and need phone accessories that won’t break.",
        ["USB-C cable"],
        budget_usd=50,
        shipping_urgency="fast",
        good_recos=["compact wall charger (PD)", "cable organizer", "power bank (airline-safe)"],
    ),
    _scenario(
        "I’m cooking more; want quick, healthy meals.",
        ["olive oil", "brown rice"],
        budget_usd=40,
        shipping_urgency="normal",
        good_recos=["nonstick skillet", "meal-prep containers", "spice blend (low sodium)"],
    ),
]
|
|
|
|
|
# One-line justification per known recommendation. Built once at import rather
# than on every loop iteration; unknown items fall back to a generic reason.
_GOOD_RECO_REASONS = {
    "knee compression sleeve": "supports the knee during runs and may reduce discomfort.",
    "foam roller": "helps with recovery and tightness after running.",
    "anti-chafe balm": "prevents irritation on longer runs.",
    "gooseneck kettle": "improves pour control for consistent extraction.",
    "digital scale": "lets you measure coffee-to-water ratio precisely.",
    "hand grinder": "fresh grind improves flavor consistency.",
    "non-comedogenic moisturizer": "hydrates without clogging pores.",
    "sunscreen SPF 30+": "protects skin daily; important even with acne care.",
    "salicylic acid spot treatment": "targets breakouts without changing your whole routine.",
    "compact wall charger (PD)": "charges faster and is more travel-friendly.",
    "cable organizer": "reduces cable damage and tangles in bags.",
    "power bank (airline-safe)": "keeps devices powered while traveling.",
    "nonstick skillet": "makes quick cooking with less oil easier.",
    "meal-prep containers": "supports batching healthy meals for the week.",
    "spice blend (low sodium)": "adds flavor without extra calories or sodium.",
}


def make_good_answer(ex, rng=None):
    """Render the well-structured, constraint-aware answer (SFT target / DPO "chosen").

    A random subset of 1-3 items from ``ex["good_recos"]`` is listed with a
    reason each. Items already in the cart are excluded first (compared
    case-insensitively). The answer then explains the choice against the
    user's intent and constraints.

    Args:
        ex: Scenario dict with ``user_intent``, ``constraints`` and
            ``good_recos`` keys, and optionally ``cart``.
        rng: Object providing ``shuffle`` and ``choice``, e.g. a
            ``random.Random`` instance. Defaults to the module-level
            ``random`` functions, which preserves the original behavior.

    Returns:
        The multi-line answer text.
    """
    rng = random if rng is None else rng
    constraints = ex["constraints"]

    # The answer below promises "no duplicates", so enforce it against the cart.
    in_cart = {str(item).casefold() for item in ex.get("cart", ())}
    recos = [item for item in ex["good_recos"] if str(item).casefold() not in in_cart]
    rng.shuffle(recos)
    recos = recos[: rng.choice([1, 2, 3])]

    lines = ["Recommendations:"]
    for i, item in enumerate(recos, 1):
        reason = _GOOD_RECO_REASONS.get(item, "fits your intent and complements your cart.")
        lines.append(f"{i}) {item} — {reason}")

    lines.append("Why these:")
    lines.append(f"- Matches your intent: {ex['user_intent']}")
    lines.append("- Complements what’s already in your cart (no duplicates).")
    if "budget_usd" in constraints:
        lines.append(f"- Stays mindful of your budget (≈ ${constraints['budget_usd']}).")

    lines.append("Compatibility / checks:")
    lines.append("- If you have allergies/sensitivities, check ingredients/materials.")
    if constraints.get("shipping_urgency") == "fast":
        lines.append("- Prefer items with fast shipping / in-stock options.")
    # Surface the avoid list. Previously the preferred answer silently dropped
    # this part of the user's constraints.
    avoid = constraints.get("brand_avoid") or []
    if avoid:
        lines.append(f"- Check labels against your avoid list: {', '.join(str(a) for a in avoid)}.")

    lines.append("Optional next step:")
    lines.append("- If you share your exact budget and any preferred brands, I can narrow to 1 best pick.")

    return "\n".join(lines)
|
|
|
|
|
def make_bad_answer(ex):
    """Render a pushy, over-promising answer used as the DPO "rejected" response.

    The scenario's good recommendations are mixed with two upsells drawn at
    random from a fixed pool. The combined list is shuffled, and at most five
    entries are listed with hype-only justifications. Randomness comes from
    the module-level ``random`` state.
    """
    upsell_pool = [
        "premium smartwatch",
        "designer headphones",
        "expensive gift card",
        "random subscription",
        "luxury item bundle",
    ]
    candidates = ex["good_recos"] + random.sample(upsell_pool, k=2)
    random.shuffle(candidates)

    hype_lines = [
        f"{rank}. {name} - best quality on the market, unbeatable."
        for rank, name in enumerate(candidates[:5], start=1)
    ]
    return "\n".join(
        [
            "You should buy these RIGHT NOW!!!",
            "Top picks (I guarantee you’ll love them):",
            *hype_lines,
            "Trust me, this will fix everything.",
        ]
    )
|
|
|
|
|
def to_sft_record(ex):
    """Build one chat-format SFT example for scenario ``ex``.

    Returns ``{"messages": [...]}``, where the list holds three turns in order:
    the system prompt, the formatted user prompt, and a freshly generated good
    answer as the assistant reply. Each turn is a ``{"role": ..., "content": ...}``
    dict.
    """
    turns = (
        ("system", SYSTEM_PROMPT),
        ("user", format_user_prompt(ex)),
        ("assistant", make_good_answer(ex)),
    )
    return {"messages": [{"role": role, "content": text} for role, text in turns]}
|
|
|
|
|
def to_dpo_record(ex):
    """Build one DPO preference pair for scenario ``ex``.

    The prompt is a flat string that uses literal ``<|system|>``, ``<|user|>``
    and ``<|assistant|>`` markers, and it ends right where the assistant reply
    would start. ``chosen`` comes from ``make_good_answer`` and ``rejected``
    from ``make_bad_answer``, generated in that order.

    NOTE(review): the SFT records use structured ``messages`` instead. Confirm
    that these hard-coded markers match the chat template the DPO trainer
    applies.
    """
    system_turn = f"<|system|>\n{SYSTEM_PROMPT}\n"
    user_turn = f"<|user|>\n{format_user_prompt(ex)}\n"
    prompt = system_turn + user_turn + "<|assistant|>\n"
    chosen = make_good_answer(ex)
    rejected = make_bad_answer(ex)
    return {"prompt": prompt, "chosen": chosen, "rejected": rejected}
|
|
|
|
|
def _write_jsonl(path, records):
    """Write ``records`` to ``path`` as JSON Lines (one object per line), overwriting it."""
    with path.open("w", encoding="utf-8") as f:
        f.writelines(json.dumps(r) + "\n" for r in records)


def main(out_dir="data", n_repeat=80, *, seed=None):
    """Generate ``sft.jsonl`` and ``dpo.jsonl`` from randomly perturbed SCENARIOS.

    Each iteration samples one scenario, deep-copies it, and randomly perturbs
    its budget and shipping urgency (about 30% chance each). It then emits one
    SFT row and one DPO row.

    Args:
        out_dir: Output directory; created (with parents) if missing.
        n_repeat: Number of sampled examples, i.e. rows written to each file.
        seed: If not None, reseed the module RNG first, so repeated calls in
            the same process produce identical files. None keeps the current
            RNG state, which was seeded once at import.
    """
    import copy

    if seed is not None:
        random.seed(seed)

    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    sft_path = out / "sft.jsonl"
    dpo_path = out / "dpo.jsonl"

    sft_records = []
    dpo_records = []
    for _ in range(n_repeat):
        # Deep-copy so the perturbations below never leak back into SCENARIOS.
        ex = copy.deepcopy(random.choice(SCENARIOS))

        if random.random() < 0.3:
            ex["constraints"]["budget_usd"] = random.choice([25, 35, 45, 60, 80])
        if random.random() < 0.3:
            ex["constraints"]["shipping_urgency"] = random.choice(["fast", "normal"])

        # Each record builder calls make_good_answer separately, so the SFT
        # answer and the DPO "chosen" answer for this ex are sampled independently.
        sft_records.append(to_sft_record(ex))
        dpo_records.append(to_dpo_record(ex))

    _write_jsonl(sft_path, sft_records)
    _write_jsonl(dpo_path, dpo_records)

    print(f"Wrote {len(sft_records)} SFT rows to {sft_path}")
    print(f"Wrote {len(dpo_records)} DPO rows to {dpo_path}")
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: write the default dataset (main's default arguments).
    main()
|
|
|