Upload 5 files
Browse files- trainer/dpo_train.py +60 -0
- trainer/make_data.py +150 -0
- trainer/requirements.txt +7 -0
- trainer/sft_train.py +56 -0
- trainer/utils_prompts.py +27 -0
trainer/dpo_train.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Minimal DPO (Direct Preference Optimization) fine-tuning script.
# Trains a LoRA adapter on the preference pairs written by make_data.py.
import os

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from peft import LoraConfig
from trl import DPOTrainer

# Base checkpoint to fine-tune; override via the BASE_MODEL env var.
BASE_MODEL = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-0.5B-Instruct")
# Optionally start from SFT adapter by setting MODEL_ADAPTER=adapter_sft (not used in this minimal version)
# NOTE(review): MODEL_ADAPTER is read here but never applied below — confirm
# whether chaining the SFT adapter into DPO training was intended.
MODEL_ADAPTER = os.environ.get("MODEL_ADAPTER", "")
+
|
| 11 |
+
def main():
    """Run DPO training on data/dpo.jsonl and save a LoRA adapter to adapter_dpo/.

    Expects JSONL rows with "prompt", "chosen" and "rejected" string fields,
    as produced by make_data.py. Saves both the adapter weights and the
    tokenizer into the adapter_dpo/ directory.
    """
    import torch  # local import: only used to probe for CUDA availability

    ds = load_dataset("json", data_files="data/dpo.jsonl")["train"]

    tok = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
    if tok.pad_token is None:
        # Causal LMs often ship without a pad token; reuse EOS for padding.
        tok.pad_token = tok.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        device_map="auto",
        torch_dtype="auto",
    )

    # LoRA over all attention and MLP projections of the Qwen2-style model.
    peft_cfg = LoraConfig(
        r=16, lora_alpha=32, lora_dropout=0.05,
        bias="none", task_type="CAUSAL_LM",
        target_modules=["q_proj","k_proj","v_proj","o_proj","up_proj","down_proj","gate_proj"]
    )

    args = TrainingArguments(
        output_dir="adapter_dpo",
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,  # effective batch size 16 per device
        learning_rate=5e-5,
        num_train_epochs=1,
        logging_steps=20,
        save_steps=200,
        # BUGFIX: fp16=True unconditionally raises a ValueError on machines
        # without a CUDA device; enable mixed precision only when available.
        fp16=torch.cuda.is_available(),
        report_to="none"
    )

    # ref_model=None: with a peft_config, trl builds the frozen reference
    # policy implicitly by disabling the adapter.
    # NOTE(review): newer trl releases moved beta/max_length into DPOConfig
    # and renamed tokenizer= to processing_class — confirm the pinned trl
    # version accepts this signature.
    trainer = DPOTrainer(
        model=model,
        ref_model=None,
        args=args,
        train_dataset=ds,
        tokenizer=tok,
        peft_config=peft_cfg,
        beta=0.1,
        max_length=1024,
        max_prompt_length=512,
    )

    trainer.train()
    trainer.save_model("adapter_dpo")
    tok.save_pretrained("adapter_dpo")
    print("Saved adapter_dpo/")

if __name__ == "__main__":
    main()
|
trainer/make_data.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Synthesizes small SFT and DPO preference datasets for the
# retail-recommendation task defined in utils_prompts.py.
import json
import random
from pathlib import Path
from utils_prompts import SYSTEM_PROMPT, format_user_prompt

# Fixed seed so regenerated datasets are reproducible run-to-run.
random.seed(7)
| 7 |
+
|
| 8 |
+
# Hand-written seed scenarios. Each entry carries the user's shopping intent,
# current cart contents, constraints (budget, shipping urgency, brands to
# avoid) and a pool of acceptable items used to build "chosen" answers.
SCENARIOS = [
    {
        "user_intent": "I run 5k daily and want to reduce knee discomfort.",
        "cart": ["running shoes", "moisture-wicking socks"],
        "constraints": {"budget_usd": 45, "shipping_urgency": "fast", "brand_avoid": []},
        "good_recos": ["knee compression sleeve", "foam roller", "anti-chafe balm"],
    },
    {
        "user_intent": "I’m setting up pour-over coffee at home and want consistent taste.",
        "cart": ["coffee beans", "paper filters"],
        "constraints": {"budget_usd": 60, "shipping_urgency": "normal", "brand_avoid": []},
        "good_recos": ["gooseneck kettle", "digital scale", "hand grinder"],
    },
    {
        "user_intent": "I get acne sometimes; want a simple skincare routine.",
        "cart": ["gentle cleanser"],
        "constraints": {"budget_usd": 35, "shipping_urgency": "normal", "brand_avoid": ["fragrance-heavy"]},
        "good_recos": ["non-comedogenic moisturizer", "sunscreen SPF 30+", "salicylic acid spot treatment"],
    },
    {
        "user_intent": "I travel weekly and need phone accessories that won’t break.",
        "cart": ["USB-C cable"],
        "constraints": {"budget_usd": 50, "shipping_urgency": "fast", "brand_avoid": []},
        "good_recos": ["compact wall charger (PD)", "cable organizer", "power bank (airline-safe)"],
    },
    {
        "user_intent": "I’m cooking more; want quick, healthy meals.",
        "cart": ["olive oil", "brown rice"],
        "constraints": {"budget_usd": 40, "shipping_urgency": "normal", "brand_avoid": []},
        "good_recos": ["nonstick skillet", "meal-prep containers", "spice blend (low sodium)"],
    },
]
|
| 40 |
+
|
| 41 |
+
def make_good_answer(ex):
|
| 42 |
+
recos = ex["good_recos"][:]
|
| 43 |
+
random.shuffle(recos)
|
| 44 |
+
recos = recos[: random.choice([1, 2, 3])]
|
| 45 |
+
|
| 46 |
+
lines = []
|
| 47 |
+
lines.append("Recommendations:")
|
| 48 |
+
for i, item in enumerate(recos, 1):
|
| 49 |
+
reason = {
|
| 50 |
+
"knee compression sleeve": "supports the knee during runs and may reduce discomfort.",
|
| 51 |
+
"foam roller": "helps with recovery and tightness after running.",
|
| 52 |
+
"anti-chafe balm": "prevents irritation on longer runs.",
|
| 53 |
+
"gooseneck kettle": "improves pour control for consistent extraction.",
|
| 54 |
+
"digital scale": "lets you measure coffee-to-water ratio precisely.",
|
| 55 |
+
"hand grinder": "fresh grind improves flavor consistency.",
|
| 56 |
+
"non-comedogenic moisturizer": "hydrates without clogging pores.",
|
| 57 |
+
"sunscreen SPF 30+": "protects skin daily; important even with acne care.",
|
| 58 |
+
"salicylic acid spot treatment": "targets breakouts without changing your whole routine.",
|
| 59 |
+
"compact wall charger (PD)": "charges faster and is more travel-friendly.",
|
| 60 |
+
"cable organizer": "reduces cable damage and tangles in bags.",
|
| 61 |
+
"power bank (airline-safe)": "keeps devices powered while traveling.",
|
| 62 |
+
"nonstick skillet": "makes quick cooking with less oil easier.",
|
| 63 |
+
"meal-prep containers": "supports batching healthy meals for the week.",
|
| 64 |
+
"spice blend (low sodium)": "adds flavor without extra calories or sodium.",
|
| 65 |
+
}.get(item, "fits your intent and complements your cart.")
|
| 66 |
+
lines.append(f"{i}) {item} — {reason}")
|
| 67 |
+
|
| 68 |
+
lines.append("Why these:")
|
| 69 |
+
lines.append(f"- Matches your intent: {ex['user_intent']}")
|
| 70 |
+
lines.append("- Complements what’s already in your cart (no duplicates).")
|
| 71 |
+
if "budget_usd" in ex["constraints"]:
|
| 72 |
+
lines.append(f"- Stays mindful of your budget (≈ ${ex['constraints']['budget_usd']}).")
|
| 73 |
+
|
| 74 |
+
lines.append("Compatibility / checks:")
|
| 75 |
+
lines.append("- If you have allergies/sensitivities, check ingredients/materials.")
|
| 76 |
+
if ex["constraints"].get("shipping_urgency") == "fast":
|
| 77 |
+
lines.append("- Prefer items with fast shipping / in-stock options.")
|
| 78 |
+
|
| 79 |
+
lines.append("Optional next step:")
|
| 80 |
+
lines.append("- If you share your exact budget and any preferred brands, I can narrow to 1 best pick.")
|
| 81 |
+
return "\n".join(lines)
|
| 82 |
+
|
| 83 |
+
def make_bad_answer(ex):
    """Build a deliberately rubric-violating "rejected" answer.

    The output is salesy and vague, pads the list with irrelevant pricey
    items, and ignores the scenario's constraints — the contrast signal
    needed for preference training.
    """
    filler = [
        "premium smartwatch", "designer headphones", "expensive gift card",
        "random subscription", "luxury item bundle"
    ]
    picks = ex["good_recos"] + random.sample(filler, k=2)
    random.shuffle(picks)

    out = [
        "You should buy these RIGHT NOW!!!",
        "Top picks (I guarantee you’ll love them):",
    ]
    out.extend(
        f"{idx}. {name} - best quality on the market, unbeatable."
        for idx, name in enumerate(picks[:5], 1)
    )
    out.append("Trust me, this will fix everything.")
    return "\n".join(out)
|
| 99 |
+
|
| 100 |
+
def to_sft_record(ex):
    """Convert scenario *ex* into one chat-format SFT row ({"messages": [...]})."""
    convo = []
    convo.append({"role": "system", "content": SYSTEM_PROMPT})
    convo.append({"role": "user", "content": format_user_prompt(ex)})
    convo.append({"role": "assistant", "content": make_good_answer(ex)})
    return {"messages": convo}
| 108 |
+
|
| 109 |
+
def to_dpo_record(ex):
    """Convert scenario *ex* into one DPO preference row (prompt/chosen/rejected)."""
    rendered_prompt = (
        f"<|system|>\n{SYSTEM_PROMPT}\n<|user|>\n{format_user_prompt(ex)}\n<|assistant|>\n"
    )
    record = {"prompt": rendered_prompt}
    record["chosen"] = make_good_answer(ex)
    record["rejected"] = make_bad_answer(ex)
    return record
| 115 |
+
|
| 116 |
+
def main(out_dir="data", n_repeat=80):
    """Generate randomized rows and write sft.jsonl / dpo.jsonl under *out_dir*.

    Args:
        out_dir: directory the JSONL files are written into (created if missing).
        n_repeat: number of rows to emit into each dataset.
    """
    from copy import deepcopy  # local import: only needed for scenario copies

    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)

    sft_path = out / "sft.jsonl"
    dpo_path = out / "dpo.jsonl"

    sft_records = []
    dpo_records = []

    for _ in range(n_repeat):
        # Deep-copy the picked scenario (was a json round-trip) so the tweaks
        # below never mutate the shared SCENARIOS templates.
        ex = deepcopy(random.choice(SCENARIOS))
        # occasionally tweak budget/urgency for light variety
        if random.random() < 0.3:
            ex["constraints"]["budget_usd"] = random.choice([25, 35, 45, 60, 80])
        if random.random() < 0.3:
            ex["constraints"]["shipping_urgency"] = random.choice(["fast", "normal"])
        sft_records.append(to_sft_record(ex))
        dpo_records.append(to_dpo_record(ex))

    # Explicit encoding so the files are byte-identical across platforms
    # instead of depending on the locale's default.
    with sft_path.open("w", encoding="utf-8") as f:
        for r in sft_records:
            f.write(json.dumps(r) + "\n")

    with dpo_path.open("w", encoding="utf-8") as f:
        for r in dpo_records:
            f.write(json.dumps(r) + "\n")

    print(f"Wrote {len(sft_records)} SFT rows to {sft_path}")
    print(f"Wrote {len(dpo_records)} DPO rows to {dpo_path}")

if __name__ == "__main__":
    main()
|
trainer/requirements.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
transformers>=4.42.0
|
| 3 |
+
datasets
|
| 4 |
+
accelerate
|
| 5 |
+
peft
|
| 6 |
+
trl
|
| 7 |
+
bitsandbytes
|
trainer/sft_train.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Minimal supervised fine-tuning (SFT) script: trains a LoRA adapter on the
# chat-format rows in data/sft.jsonl produced by make_data.py.
import os

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from trl import SFTTrainer
from peft import LoraConfig

# Base checkpoint to fine-tune; override via the BASE_MODEL env var.
BASE_MODEL = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-0.5B-Instruct")
| 8 |
+
|
| 9 |
+
def main():
    """Run SFT on data/sft.jsonl and save a LoRA adapter to adapter_sft/.

    Expects JSONL rows with a "messages" list (system/user/assistant turns),
    as produced by make_data.py. Saves both the adapter weights and the
    tokenizer into the adapter_sft/ directory.
    """
    import torch  # local import: only used to probe for CUDA availability

    ds = load_dataset("json", data_files="data/sft.jsonl")["train"]

    tok = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
    if tok.pad_token is None:
        # Causal LMs often ship without a pad token; reuse EOS for padding.
        tok.pad_token = tok.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        device_map="auto",
        torch_dtype="auto",
    )

    # LoRA over all attention and MLP projections of the Qwen2-style model.
    peft_cfg = LoraConfig(
        r=16, lora_alpha=32, lora_dropout=0.05,
        bias="none", task_type="CAUSAL_LM",
        target_modules=["q_proj","k_proj","v_proj","o_proj","up_proj","down_proj","gate_proj"]
    )

    args = TrainingArguments(
        output_dir="adapter_sft",
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,  # effective batch size 16 per device
        learning_rate=2e-4,
        num_train_epochs=1,
        logging_steps=20,
        save_steps=200,
        # BUGFIX: fp16=True unconditionally raises a ValueError on machines
        # without a CUDA device; enable mixed precision only when available.
        fp16=torch.cuda.is_available(),
        report_to="none"
    )

    # dataset_text_field=None so trl applies the tokenizer's chat template to
    # the "messages" column instead of reading a plain-text column.
    # NOTE(review): newer trl releases moved max_seq_length/packing into
    # SFTConfig and renamed tokenizer= to processing_class — confirm the
    # pinned trl version accepts this signature.
    trainer = SFTTrainer(
        model=model,
        tokenizer=tok,
        train_dataset=ds,
        peft_config=peft_cfg,
        max_seq_length=1024,
        args=args,
        packing=False,
        dataset_text_field=None,  # because we use "messages"
    )
    trainer.train()
    trainer.save_model("adapter_sft")
    tok.save_pretrained("adapter_sft")
    print("Saved adapter_sft/")

if __name__ == "__main__":
    main()
|
trainer/utils_prompts.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Shared system prompt: defines the assistant's role, the behavioral rubric,
# and the exact output template used by both the SFT and DPO data generators.
SYSTEM_PROMPT = """You are a retail recommendation assistant.
You recommend at most 3 items that complement the user's cart and intent.
You must be:
- Relevant to the cart + intent
- Constraint-aware (budget, urgency, compatibility, brand preferences)
- Non-pushy and honest (no made-up specs or guarantees)
- Concise and structured

Output format:
Recommendations:
1) <item> — <one-line reason>
2) ...
Why these:
- ...
Compatibility / checks:
- ...
Optional next step:
- (only if helpful)
"""
| 20 |
+
|
| 21 |
+
def format_user_prompt(example: dict) -> str:
    """Render one scenario dict as the user turn of the prompt."""
    parts = [
        f"User intent: {example['user_intent']}",
        f"Cart: {', '.join(example['cart'])}",
        f"Constraints: {example.get('constraints', {})}",
        "Generate recommendations following the required format.",
    ]
    return "\n".join(parts)
|