Fahad-sha committed on
Commit
a365d48
·
verified ·
1 Parent(s): bd07c7a

Upload 5 files

Browse files
trainer/dpo_train.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from datasets import load_dataset
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
4
+ from peft import LoraConfig
5
+ from trl import DPOTrainer
6
+
7
# Base model checkpoint; override via the BASE_MODEL env var.
BASE_MODEL = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-0.5B-Instruct")
# Optionally start from SFT adapter by setting MODEL_ADAPTER=adapter_sft (not used in this minimal version)
MODEL_ADAPTER = os.environ.get("MODEL_ADAPTER", "")
10
+
11
def main():
    """Run a minimal LoRA DPO fine-tune over data/dpo.jsonl and save the adapter.

    Reads prompt/chosen/rejected triples (produced by make_data.py), trains
    the base model with a LoRA adapter via TRL's DPOTrainer, and writes the
    result to adapter_dpo/.
    """
    # The json loader exposes the local jsonl file as a single "train" split.
    ds = load_dataset("json", data_files="data/dpo.jsonl")["train"]

    tok = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
    if tok.pad_token is None:
        # Many causal LMs ship without a pad token; reuse EOS for padding.
        tok.pad_token = tok.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        device_map="auto",   # place layers on available devices automatically
        torch_dtype="auto",  # use the checkpoint's native dtype
    )

    # LoRA over all attention and MLP projection matrices.
    peft_cfg = LoraConfig(
        r=16, lora_alpha=32, lora_dropout=0.05,
        bias="none", task_type="CAUSAL_LM",
        target_modules=["q_proj","k_proj","v_proj","o_proj","up_proj","down_proj","gate_proj"]
    )

    # Effective batch size = 2 * 8 = 16 sequences per optimizer step.
    args = TrainingArguments(
        output_dir="adapter_dpo",
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        learning_rate=5e-5,
        num_train_epochs=1,
        logging_steps=20,
        save_steps=200,
        fp16=True,  # NOTE(review): fp16 requires CUDA — confirm target hardware
        report_to="none"
    )

    # ref_model=None: with a peft_config, TRL derives the frozen reference
    # model from the base weights instead of loading a second copy.
    # NOTE(review): passing tokenizer=, beta=, max_length= and
    # max_prompt_length= directly to DPOTrainer only works on older trl
    # releases (roughly < 0.12); newer trl moves these into DPOConfig and
    # renames tokenizer -> processing_class. requirements.txt leaves trl
    # unpinned — pin it or migrate before relying on this script.
    trainer = DPOTrainer(
        model=model,
        ref_model=None,
        args=args,
        train_dataset=ds,
        tokenizer=tok,
        peft_config=peft_cfg,
        beta=0.1,               # KL-penalty strength vs. the reference model
        max_length=1024,        # total prompt+completion token budget
        max_prompt_length=512,  # prompt portion truncated to this length
    )

    trainer.train()
    trainer.save_model("adapter_dpo")   # saves the (adapter) weights
    tok.save_pretrained("adapter_dpo")  # keep tokenizer alongside the adapter
    print("Saved adapter_dpo/")

if __name__ == "__main__":
    main()
trainer/make_data.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import random
3
+ from pathlib import Path
4
+ from utils_prompts import SYSTEM_PROMPT, format_user_prompt
5
+
6
# Fix the RNG so the generated datasets are reproducible across runs.
random.seed(7)

# Hand-written seed scenarios. Each pairs a shopping intent and cart with
# constraint metadata and a pool of reasonable add-on recommendations that
# make_good_answer() draws from.
SCENARIOS = [
    {
        "user_intent": "I run 5k daily and want to reduce knee discomfort.",
        "cart": ["running shoes", "moisture-wicking socks"],
        "constraints": {"budget_usd": 45, "shipping_urgency": "fast", "brand_avoid": []},
        "good_recos": ["knee compression sleeve", "foam roller", "anti-chafe balm"],
    },
    {
        "user_intent": "I’m setting up pour-over coffee at home and want consistent taste.",
        "cart": ["coffee beans", "paper filters"],
        "constraints": {"budget_usd": 60, "shipping_urgency": "normal", "brand_avoid": []},
        "good_recos": ["gooseneck kettle", "digital scale", "hand grinder"],
    },
    {
        "user_intent": "I get acne sometimes; want a simple skincare routine.",
        "cart": ["gentle cleanser"],
        "constraints": {"budget_usd": 35, "shipping_urgency": "normal", "brand_avoid": ["fragrance-heavy"]},
        "good_recos": ["non-comedogenic moisturizer", "sunscreen SPF 30+", "salicylic acid spot treatment"],
    },
    {
        "user_intent": "I travel weekly and need phone accessories that won’t break.",
        "cart": ["USB-C cable"],
        "constraints": {"budget_usd": 50, "shipping_urgency": "fast", "brand_avoid": []},
        "good_recos": ["compact wall charger (PD)", "cable organizer", "power bank (airline-safe)"],
    },
    {
        "user_intent": "I’m cooking more; want quick, healthy meals.",
        "cart": ["olive oil", "brown rice"],
        "constraints": {"budget_usd": 40, "shipping_urgency": "normal", "brand_avoid": []},
        "good_recos": ["nonstick skillet", "meal-prep containers", "spice blend (low sodium)"],
    },
]
40
+
41
def make_good_answer(ex):
    """Build a rubric-compliant assistant reply for one scenario.

    Picks 1-3 of the scenario's curated recommendations (order and count
    randomized via the module RNG) and renders them in the structure demanded
    by SYSTEM_PROMPT: numbered items, rationale, compatibility checks, and an
    optional follow-up.

    Args:
        ex: scenario dict with "user_intent", "constraints" and "good_recos".

    Returns:
        The formatted multi-line answer string.
    """
    recos = ex["good_recos"][:]
    random.shuffle(recos)
    recos = recos[: random.choice([1, 2, 3])]

    # One-line justification per known item. Hoisted out of the loop (it is
    # loop-invariant); the original rebuilt this dict on every iteration.
    reason_by_item = {
        "knee compression sleeve": "supports the knee during runs and may reduce discomfort.",
        "foam roller": "helps with recovery and tightness after running.",
        "anti-chafe balm": "prevents irritation on longer runs.",
        "gooseneck kettle": "improves pour control for consistent extraction.",
        "digital scale": "lets you measure coffee-to-water ratio precisely.",
        "hand grinder": "fresh grind improves flavor consistency.",
        "non-comedogenic moisturizer": "hydrates without clogging pores.",
        "sunscreen SPF 30+": "protects skin daily; important even with acne care.",
        "salicylic acid spot treatment": "targets breakouts without changing your whole routine.",
        "compact wall charger (PD)": "charges faster and is more travel-friendly.",
        "cable organizer": "reduces cable damage and tangles in bags.",
        "power bank (airline-safe)": "keeps devices powered while traveling.",
        "nonstick skillet": "makes quick cooking with less oil easier.",
        "meal-prep containers": "supports batching healthy meals for the week.",
        "spice blend (low sodium)": "adds flavor without extra calories or sodium.",
    }

    lines = ["Recommendations:"]
    for i, item in enumerate(recos, 1):
        # Fall back to a generic justification for items not in the table.
        reason = reason_by_item.get(item, "fits your intent and complements your cart.")
        lines.append(f"{i}) {item} — {reason}")

    lines.append("Why these:")
    lines.append(f"- Matches your intent: {ex['user_intent']}")
    lines.append("- Complements what’s already in your cart (no duplicates).")
    if "budget_usd" in ex["constraints"]:
        lines.append(f"- Stays mindful of your budget (≈ ${ex['constraints']['budget_usd']}).")

    lines.append("Compatibility / checks:")
    lines.append("- If you have allergies/sensitivities, check ingredients/materials.")
    if ex["constraints"].get("shipping_urgency") == "fast":
        lines.append("- Prefer items with fast shipping / in-stock options.")

    lines.append("Optional next step:")
    lines.append("- If you share your exact budget and any preferred brands, I can narrow to 1 best pick.")
    return "\n".join(lines)
82
+
83
def make_bad_answer(ex):
    """Produce a deliberately rubric-violating reply for one scenario.

    The output is intentionally bad training signal for DPO's "rejected"
    side: too many items, salesy hype, made-up guarantees, and no attention
    to the scenario's constraints.
    """
    junk_items = [
        "premium smartwatch", "designer headphones", "expensive gift card",
        "random subscription", "luxury item bundle"
    ]
    # Same RNG call order as always: sample two junk items, then shuffle.
    picks = ex["good_recos"] + random.sample(junk_items, k=2)
    random.shuffle(picks)

    numbered = [
        f"{rank}. {name} - best quality on the market, unbeatable."
        for rank, name in enumerate(picks[:5], 1)
    ]
    parts = (
        [
            "You should buy these RIGHT NOW!!!",
            "Top picks (I guarantee you’ll love them):",
        ]
        + numbered
        + ["Trust me, this will fix everything."]
    )
    return "\n".join(parts)
99
+
100
def to_sft_record(ex):
    """Wrap one scenario as a chat-format SFT example (system/user/assistant)."""
    system_msg = {"role": "system", "content": SYSTEM_PROMPT}
    user_msg = {"role": "user", "content": format_user_prompt(ex)}
    assistant_msg = {"role": "assistant", "content": make_good_answer(ex)}
    return {"messages": [system_msg, user_msg, assistant_msg]}
108
+
109
def to_dpo_record(ex):
    """Wrap one scenario as a DPO preference pair: prompt + chosen/rejected."""
    prompt_text = (
        f"<|system|>\n{SYSTEM_PROMPT}\n"
        f"<|user|>\n{format_user_prompt(ex)}\n"
        "<|assistant|>\n"
    )
    return {
        "prompt": prompt_text,
        "chosen": make_good_answer(ex),
        "rejected": make_bad_answer(ex),
    }
115
+
116
def main(out_dir="data", n_repeat=80):
    """Generate paired SFT and DPO training files from the seed scenarios.

    Samples `n_repeat` scenarios (with replacement), lightly perturbs budget
    and shipping urgency on some of them, and writes one jsonl row per sample
    to <out_dir>/sft.jsonl and <out_dir>/dpo.jsonl.

    Args:
        out_dir: directory for the output files (created if missing).
        n_repeat: number of rows to emit into each file.
    """
    import copy  # stdlib; function-scope to keep the public import block unchanged

    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)

    sft_path = out / "sft.jsonl"
    dpo_path = out / "dpo.jsonl"

    sft_records = []
    dpo_records = []

    for _ in range(n_repeat):
        # Deep-copy so the perturbations below never mutate the shared
        # SCENARIOS templates (clearer than the json round-trip hack).
        ex = copy.deepcopy(random.choice(SCENARIOS))
        # Occasionally tweak budget/urgency for light variety.
        if random.random() < 0.3:
            ex["constraints"]["budget_usd"] = random.choice([25, 35, 45, 60, 80])
        if random.random() < 0.3:
            ex["constraints"]["shipping_urgency"] = random.choice(["fast", "normal"])
        sft_records.append(to_sft_record(ex))
        dpo_records.append(to_dpo_record(ex))

    # Explicit encoding keeps the output byte-identical across platforms
    # (json.dumps defaults to ASCII-escaped output, so content is unaffected).
    with sft_path.open("w", encoding="utf-8") as f:
        for r in sft_records:
            f.write(json.dumps(r) + "\n")

    with dpo_path.open("w", encoding="utf-8") as f:
        for r in dpo_records:
            f.write(json.dumps(r) + "\n")

    print(f"Wrote {len(sft_records)} SFT rows to {sft_path}")
    print(f"Wrote {len(dpo_records)} DPO rows to {dpo_path}")

if __name__ == "__main__":
    main()
trainer/requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch
2
+ transformers>=4.42.0
3
+ datasets
4
+ accelerate
5
+ peft
6
+ trl<0.12  # pinned: sft_train.py/dpo_train.py pass tokenizer=/beta=/max_seq_length= directly to the trainers, which newer trl removed
7
+ bitsandbytes
trainer/sft_train.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from datasets import load_dataset
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
4
+ from trl import SFTTrainer
5
+ from peft import LoraConfig
6
+
7
# Base model checkpoint; override via the BASE_MODEL env var.
BASE_MODEL = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-0.5B-Instruct")
8
+
9
def main():
    """Run a minimal LoRA SFT pass over data/sft.jsonl and save the adapter.

    Loads chat-format records ({"messages": [...]} produced by make_data.py),
    attaches a LoRA adapter to the base model, trains with TRL's SFTTrainer,
    and writes the result to adapter_sft/.
    """
    # The json loader exposes the local jsonl file as a single "train" split.
    ds = load_dataset("json", data_files="data/sft.jsonl")["train"]

    tok = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
    if tok.pad_token is None:
        # Many causal LMs ship without a pad token; reuse EOS for padding.
        tok.pad_token = tok.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        device_map="auto",   # place layers on available devices automatically
        torch_dtype="auto",  # use the checkpoint's native dtype
    )

    # LoRA over all attention and MLP projection matrices.
    peft_cfg = LoraConfig(
        r=16, lora_alpha=32, lora_dropout=0.05,
        bias="none", task_type="CAUSAL_LM",
        target_modules=["q_proj","k_proj","v_proj","o_proj","up_proj","down_proj","gate_proj"]
    )

    # Effective batch size = 2 * 8 = 16 sequences per optimizer step.
    args = TrainingArguments(
        output_dir="adapter_sft",
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        learning_rate=2e-4,
        num_train_epochs=1,
        logging_steps=20,
        save_steps=200,
        fp16=True,  # NOTE(review): fp16 requires CUDA — confirm target hardware
        report_to="none"
    )

    # NOTE(review): tokenizer=, max_seq_length=, packing= and
    # dataset_text_field= as direct SFTTrainer kwargs only work on older trl
    # releases; newer trl moves them into SFTConfig and renames
    # tokenizer -> processing_class. requirements.txt leaves trl unpinned —
    # pin it or migrate before relying on this script.
    trainer = SFTTrainer(
        model=model,
        tokenizer=tok,
        train_dataset=ds,
        peft_config=peft_cfg,
        max_seq_length=1024,
        args=args,
        packing=False,
        dataset_text_field=None,  # because we use "messages"
    )
    trainer.train()
    trainer.save_model("adapter_sft")   # saves the (adapter) weights
    tok.save_pretrained("adapter_sft")  # keep tokenizer alongside the adapter
    print("Saved adapter_sft/")

if __name__ == "__main__":
    main()
trainer/utils_prompts.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Shared system prompt: defines the assistant persona, the rubric (at most 3
# relevant, constraint-aware, non-pushy recommendations) and the exact output
# skeleton that generated "good" answers follow.
SYSTEM_PROMPT = """You are a retail recommendation assistant.
You recommend at most 3 items that complement the user's cart and intent.
You must be:
- Relevant to the cart + intent
- Constraint-aware (budget, urgency, compatibility, brand preferences)
- Non-pushy and honest (no made-up specs or guarantees)
- Concise and structured

Output format:
Recommendations:
1) <item> — <one-line reason>
2) ...
Why these:
- ...
Compatibility / checks:
- ...
Optional next step:
- (only if helpful)
"""
20
+
21
def format_user_prompt(example: dict) -> str:
    """Render one scenario dict as the user turn expected by the system prompt."""
    cart_csv = ", ".join(example["cart"])
    constraints = example.get("constraints", {})
    lines = [
        f"User intent: {example['user_intent']}",
        f"Cart: {cart_csv}",
        f"Constraints: {constraints}",
        "Generate recommendations following the required format.",
    ]
    return "\n".join(lines)