tobil committed on
Commit 98a3f77 · verified · 1 Parent(s): a330af8

Upload train_grpo.py with huggingface_hub

Files changed (1)
  1. train_grpo.py +292 -0
train_grpo.py ADDED
@@ -0,0 +1,292 @@
+ # /// script
+ # requires-python = ">=3.10"
+ # dependencies = [
+ #     "trl>=0.14.0",  # GRPOTrainer was introduced in trl 0.14
+ #     "peft>=0.7.0",
+ #     "transformers>=4.45.0",
+ #     "accelerate>=0.24.0",
+ #     "trackio",
+ #     "datasets",
+ #     "bitsandbytes",
+ #     "sentence-transformers",
+ # ]
+ # ///
+ """
+ GRPO (Group Relative Policy Optimization) training for QMD query expansion.
+
+ Reward Type 2: Format + Diversity
+ - Rewards the correct lex/vec/hyde format
+ - Penalizes repetition between lines
+ - Rewards semantic diversity of the expansions
+
+ Usage:
+     uv run train_grpo.py --sft-model tobil/qmd-query-expansion-0.6B
+ """
+
+ import re
+
+ import torch
+ import trackio
+ from datasets import load_dataset
+ from peft import PeftModel
+ from sentence_transformers import SentenceTransformer
+ from torch.nn.functional import cosine_similarity
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from trl import GRPOConfig, GRPOTrainer
+
+
+ # ============================================================================
+ # Reward Function: Format + Diversity
+ # ============================================================================
+
+ def parse_expansion(text: str) -> dict:
+     """Parse expansion output into lex/vec/hyde components."""
+     result = {"lex": [], "vec": [], "hyde": []}
+
+     for line in text.strip().split("\n"):
+         line = line.strip()
+         if line.startswith("lex:"):
+             result["lex"].append(line[4:].strip())
+         elif line.startswith("vec:"):
+             result["vec"].append(line[4:].strip())
+         elif line.startswith("hyde:"):
+             result["hyde"].append(line[5:].strip())
+
+     return result
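+
+ # Example, derived from the parser above:
+ #     parse_expansion("lex: foo\nvec: bar")
+ #     -> {"lex": ["foo"], "vec": ["bar"], "hyde": []}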
+
+
+ def compute_format_reward(text: str) -> float:
+     """
+     Reward for correct format:
+     - Has at least 1 lex line: +0.2
+     - Has at least 1 vec line: +0.2
+     - Has hyde line: +0.1
+     - Correct line format (type: content): +0.1 per line (max 0.3)
+     - No garbage/malformed lines: +0.2
+     """
+     reward = 0.0
+     parsed = parse_expansion(text)
+
+     # Check required components
+     if parsed["lex"]:
+         reward += 0.2
+     if parsed["vec"]:
+         reward += 0.2
+     if parsed["hyde"]:
+         reward += 0.1
+
+     # Check line format
+     lines = text.strip().split("\n")
+     valid_lines = 0
+     for line in lines:
+         if re.match(r'^(lex|vec|hyde):\s*.+', line.strip()):
+             valid_lines += 1
+
+     reward += min(0.3, valid_lines * 0.1)
+
+     # Penalize malformed lines
+     malformed = len(lines) - valid_lines
+     if malformed == 0:
+         reward += 0.2
+     else:
+         reward -= malformed * 0.1
+
+     return max(0.0, min(1.0, reward))
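+
+ # Worked example: "lex: a\nlex: b\nvec: c\nhyde: d" scores
+ # 0.2 (lex) + 0.2 (vec) + 0.1 (hyde) + 0.3 (4 valid lines, capped)
+ # + 0.2 (no malformed lines) = 1.0.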
+
+
+ def compute_diversity_reward(text: str, embedder) -> float:
+     """
+     Reward for diverse expansions:
+     - Penalize exact duplicates
+     - Reward semantic distance between expansions
+     """
+     parsed = parse_expansion(text)
+     all_expansions = parsed["lex"] + parsed["vec"] + parsed["hyde"]
+
+     if len(all_expansions) < 2:
+         return 0.0
+
+     # Penalize exact duplicates (case-insensitive)
+     unique = set(e.lower() for e in all_expansions)
+     duplicate_penalty = (len(all_expansions) - len(unique)) * 0.2
+
+     # Compute semantic diversity from pairwise cosine similarities
+     if len(unique) >= 2:
+         try:
+             embeddings = embedder.encode(list(unique))
+             emb_tensor = torch.tensor(embeddings)
+
+             similarities = []
+             for i in range(len(emb_tensor)):
+                 for j in range(i + 1, len(emb_tensor)):
+                     sim = cosine_similarity(
+                         emb_tensor[i].unsqueeze(0),
+                         emb_tensor[j].unsqueeze(0)
+                     ).item()
+                     similarities.append(sim)
+
+             # Lower similarity = higher diversity = higher reward
+             avg_similarity = sum(similarities) / len(similarities) if similarities else 1.0
+             diversity_reward = 1.0 - avg_similarity  # 0 = identical, 1 = orthogonal
+         except Exception:
+             diversity_reward = 0.0
+     else:
+         diversity_reward = 0.0
+
+     return max(0.0, diversity_reward - duplicate_penalty)
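+
+ # Worked example: 4 expansions, one an exact duplicate, with an average
+ # pairwise similarity of 0.6 across the 3 unique strings:
+ # (1.0 - 0.6) - 1 * 0.2 = 0.2.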
+
+
+ def compute_length_reward(text: str) -> float:
+     """Reward appropriate length (not too short, not too long)."""
+     lines = [l for l in text.strip().split("\n") if l.strip()]
+
+     # Ideal: 3-6 lines
+     if 3 <= len(lines) <= 6:
+         return 0.2
+     elif 2 <= len(lines) <= 7:
+         return 0.1
+     else:
+         return 0.0
+
+
+ class QMDRewardFunction:
+     """Combined reward function for QMD query expansion."""
+
+     def __init__(self):
+         # GRPOTrainer logs custom reward functions by their __name__,
+         # which a plain class instance does not have by default.
+         self.__name__ = "qmd_format_diversity_reward"
+         # Load a small embedding model for diversity computation
+         print("Loading embedding model for diversity reward...")
+         self.embedder = SentenceTransformer('all-MiniLM-L6-v2')
+         print("Embedding model loaded.")
+
+     # TRL calls reward functions with keyword arguments and may pass extra
+     # dataset columns, so the signature accepts **kwargs.
+     def __call__(
+         self, completions: list[str], prompts: list[str] | None = None, **kwargs
+     ) -> list[float]:
+         """Compute rewards for a batch of completions."""
+         rewards = []
+
+         for completion in completions:
+             # GRPOTrainer passes only the generated completion (prompt excluded)
+             text = completion
+
+             # Compute component rewards
+             format_r = compute_format_reward(text)
+             diversity_r = compute_diversity_reward(text, self.embedder)
+             length_r = compute_length_reward(text)
+
+             # Weighted combination
+             total = (
+                 0.5 * format_r +      # Format is most important
+                 0.35 * diversity_r +  # Diversity is second
+                 0.15 * length_r       # Length is minor
+             )
+
+             rewards.append(total)
+
+         return rewards
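+
+ # Illustrative call shape (strings made up):
+ #     fn = QMDRewardFunction()
+ #     fn(completions=["lex: a\nvec: b\nhyde: c"], prompts=["some query"])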
+
+
+ # ============================================================================
+ # Main Training
+ # ============================================================================
+
+ def main():
+     import argparse
+
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--sft-model", default="tobil/qmd-query-expansion-0.6B",
+                         help="SFT model to use as starting point")
+     parser.add_argument("--base-model", default="Qwen/Qwen3-0.6B",
+                         help="Base model (for loading the tokenizer)")
+     parser.add_argument("--output", default="tobil/qmd-query-expansion-0.6B-grpo",
+                         help="Output model name on the Hub")
+     parser.add_argument("--epochs", type=int, default=1)
+     parser.add_argument("--dry-run", action="store_true")
+     args = parser.parse_args()
+
+     if args.dry_run:
+         print("GRPO Training Config:")
+         print(f"  SFT Model:  {args.sft_model}")
+         print(f"  Base Model: {args.base_model}")
+         print(f"  Output:     {args.output}")
+         print(f"  Epochs:     {args.epochs}")
+         return
+
+     # Load dataset (GRPO only needs prompts; completions are sampled online)
+     print("Loading dataset...")
+     dataset = load_dataset("tobil/qmd-query-expansion-train", split="train")
+
+     # Extract just the queries as prompts
+     def extract_prompt(example):
+         return {"prompt": example["messages"][0]["content"]}
+
+     dataset = dataset.map(extract_prompt, remove_columns=dataset.column_names)
+     # Use a subset for GRPO
+     dataset = dataset.shuffle(seed=42).select(range(min(2000, len(dataset))))
+     print(f"Using {len(dataset)} prompts for GRPO")
+
+     # Load tokenizer
+     print(f"Loading tokenizer from {args.base_model}...")
+     tokenizer = AutoTokenizer.from_pretrained(args.base_model)
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+
+     # Load the SFT model and merge its LoRA adapter
+     print(f"Loading SFT model from {args.sft_model}...")
+     base_model = AutoModelForCausalLM.from_pretrained(
+         args.base_model,
+         torch_dtype=torch.bfloat16,
+         device_map="auto",
+     )
+     model = PeftModel.from_pretrained(base_model, args.sft_model)
+     model = model.merge_and_unload()  # Merge LoRA weights into the base model
+     print("Model loaded and LoRA merged.")
+
+     # Initialize reward function
+     reward_fn = QMDRewardFunction()
+
+     # GRPO config
+     config = GRPOConfig(
+         output_dir="qmd-expansion-grpo",
+         push_to_hub=True,
+         hub_model_id=args.output,
+
+         # GRPO specific
+         num_generations=4,          # Generate 4 completions per prompt
+         max_completion_length=256,  # GRPOConfig's name for max new tokens
+         temperature=0.8,
+
+         # Training
+         num_train_epochs=args.epochs,
+         # TRL requires the global batch size to be divisible by num_generations
+         per_device_train_batch_size=4,
+         gradient_accumulation_steps=4,
+         learning_rate=5e-6,  # Lower LR for RL
+
+         # Logging
+         logging_steps=10,
+         save_strategy="epoch",
+
+         # Monitoring
+         report_to="trackio",
+         project="qmd-query-expansion-grpo",
+         run_name="grpo-format-diversity",
+     )
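+
+     # Effective batch per optimizer step (single-GPU assumption):
+     # 4 (per device) * 4 (grad accum) = 16 completions = 4 prompt groups.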
+
+     # Create trainer
+     print("Initializing GRPO trainer...")
+     trainer = GRPOTrainer(
+         model=model,
+         processing_class=tokenizer,  # GRPOTrainer takes processing_class, not tokenizer
+         args=config,                 # and args, not config
+         train_dataset=dataset,
+         reward_funcs=reward_fn,
+     )
+
+     # Train
+     print("Starting GRPO training...")
+     trainer.train()
+
+     # Save
+     print("Pushing to Hub...")
+     trainer.push_to_hub()
+
+     trackio.finish()
+     print(f"Done! Model at: https://huggingface.co/{args.output}")
+
+
+ if __name__ == "__main__":
+     main()