tobil committed on
Commit
c7967b0
·
verified ·
1 Parent(s): 58867a4

Upload train_grpo.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_grpo.py +188 -133
train_grpo.py CHANGED
@@ -8,16 +8,16 @@
8
  # "trackio",
9
  # "datasets",
10
  # "bitsandbytes",
11
- # "sentence-transformers",
12
  # ]
13
  # ///
14
  """
15
  GRPO (Group Relative Policy Optimization) training for QMD query expansion.
16
 
17
- Reward Type 2: Format + Diversity
18
- - Rewards correct lex/vec/hyde format
19
- - Penalizes repetition between lines
20
- - Rewards semantic diversity of expansions
 
21
 
22
  Usage:
23
  uv run train_grpo.py --sft-model tobil/qmd-query-expansion-0.6B
@@ -26,159 +26,204 @@ Usage:
26
  import re
27
  import torch
28
  import trackio
 
29
  from datasets import load_dataset
30
- from peft import LoraConfig, PeftModel
31
  from transformers import AutoModelForCausalLM, AutoTokenizer
32
  from trl import GRPOTrainer, GRPOConfig
33
- from sentence_transformers import SentenceTransformer
 
34
 
35
  # ============================================================================
36
- # Reward Function: Format + Diversity
37
  # ============================================================================
38
 
39
def parse_expansion(text: str) -> dict:
    """Parse expansion output into lex/vec/hyde components."""
    buckets = {"lex": [], "vec": [], "hyde": []}
    # (prefix, prefix length) pairs; lines matching none are silently dropped
    markers = (("lex:", 4), ("vec:", 4), ("hyde:", 5))

    for raw in text.strip().split("\n"):
        entry = raw.strip()
        for prefix, width in markers:
            if entry.startswith(prefix):
                buckets[prefix[:-1]].append(entry[width:].strip())
                break

    return buckets
53
 
54
 
55
def compute_format_reward(text: str) -> float:
    """
    Reward for correct format:
    - Has at least 1 lex line: +0.2
    - Has at least 1 vec line: +0.2
    - Has hyde line: +0.1
    - Correct line format (type: content): +0.1 per line (max 0.3)
    - No garbage/malformed lines: +0.2
    """
    stripped = [ln.strip() for ln in text.strip().split("\n")]

    reward = 0.0
    # Required components (presence of each prefixed line type).
    if any(ln.startswith("lex:") for ln in stripped):
        reward += 0.2
    if any(ln.startswith("vec:") for ln in stripped):
        reward += 0.2
    if any(ln.startswith("hyde:") for ln in stripped):
        reward += 0.1

    # Well-formed "type: content" lines, capped at 0.3 total.
    well_formed = sum(1 for ln in stripped if re.match(r'^(lex|vec|hyde):\s*.+', ln))
    reward += min(0.3, well_formed * 0.1)

    # Bonus for a fully clean output, penalty per malformed line otherwise.
    malformed = len(stripped) - well_formed
    if malformed == 0:
        reward += 0.2
    else:
        reward -= malformed * 0.1

    # Clamp to [0, 1].
    return max(0.0, min(1.0, reward))
92
 
 
 
 
 
 
 
 
 
 
93
 
94
def compute_diversity_reward(text: str, embedder) -> float:
    """
    Reward for diverse expansions:
    - Penalize exact duplicates
    - Reward semantic distance between expansions
    """
    parsed = parse_expansion(text)
    candidates = parsed["lex"] + parsed["vec"] + parsed["hyde"]

    # Nothing to compare against.
    if len(candidates) < 2:
        return 0.0

    # Each exact (case-insensitive) duplicate costs 0.2.
    unique = set(c.lower() for c in candidates)
    duplicate_penalty = 0.2 * (len(candidates) - len(unique))

    diversity_reward = 0.0
    if len(unique) >= 2:
        try:
            from torch.nn.functional import cosine_similarity
            vectors = torch.tensor(embedder.encode(list(unique)))

            # All pairwise cosine similarities between distinct expansions.
            sims = [
                cosine_similarity(
                    vectors[i].unsqueeze(0),
                    vectors[j].unsqueeze(0),
                ).item()
                for i in range(len(vectors))
                for j in range(i + 1, len(vectors))
            ]

            mean_sim = sum(sims) / len(sims) if sims else 1.0
            # Lower similarity = higher diversity (0 = identical, 1 = orthogonal).
            diversity_reward = 1.0 - mean_sim
        except Exception:
            # Best-effort: embedding failures yield no diversity credit.
            diversity_reward = 0.0

    return max(0.0, diversity_reward - duplicate_penalty)
 
 
136
 
 
 
137
 
138
def compute_length_reward(text: str) -> float:
    """Reward appropriate length (not too short, not too long)."""
    # Count non-blank lines only; ideal output is 3-6 lines.
    line_count = sum(1 for ln in text.strip().split("\n") if ln.strip())
    if 3 <= line_count <= 6:
        return 0.2
    if 2 <= line_count <= 7:
        return 0.1
    return 0.0
149
 
150
 
151
class QMDRewardFunction:
    """Combined reward function for QMD query expansion."""
    # TRL reads __name__ for logging the reward function's identity.
    __name__ = "qmd_format_diversity_reward"

    def __init__(self):
        # Load a small embedding model once; reused for every diversity call.
        print("Loading embedding model for diversity reward...")
        self.embedder = SentenceTransformer('all-MiniLM-L6-v2')
        print("Embedding model loaded.")

    def __call__(self, completions: list[str], prompts: list[str] = None, **kwargs) -> list[float]:
        """Compute rewards for a batch of completions."""
        scores = []
        for text in completions:
            # Weighted blend: format dominates, then diversity, then length.
            combined = (
                0.5 * compute_format_reward(text)
                + 0.35 * compute_diversity_reward(text, self.embedder)
                + 0.15 * compute_length_reward(text)
            )
            scores.append(combined)
        return scores
184
 
@@ -194,9 +239,11 @@ def main():
194
  help="SFT model to use as starting point")
195
  parser.add_argument("--base-model", default="Qwen/Qwen3-0.6B",
196
  help="Base model (for loading tokenizer)")
197
- parser.add_argument("--output", default="tobil/qmd-query-expansion-0.6B-grpo",
198
  help="Output model name on Hub")
199
  parser.add_argument("--epochs", type=int, default=1)
 
 
200
  parser.add_argument("--dry-run", action="store_true")
201
  args = parser.parse_args()
202
 
@@ -206,6 +253,7 @@ def main():
206
  print(f" Base Model: {args.base_model}")
207
  print(f" Output: {args.output}")
208
  print(f" Epochs: {args.epochs}")
 
209
  return
210
 
211
  # Load dataset (just prompts needed for GRPO)
@@ -217,7 +265,7 @@ def main():
217
  return {"prompt": example["messages"][0]["content"]}
218
 
219
  dataset = dataset.map(extract_prompt, remove_columns=dataset.column_names)
220
- dataset = dataset.shuffle(seed=42).select(range(min(2000, len(dataset)))) # Use subset for GRPO
221
  print(f"Using {len(dataset)} prompts for GRPO")
222
 
223
  # Load tokenizer
@@ -234,18 +282,17 @@ def main():
234
  device_map="auto",
235
  )
236
  model = PeftModel.from_pretrained(base_model, args.sft_model)
237
- model = model.merge_and_unload() # Merge LoRA weights
238
  print("Model loaded and LoRA merged.")
239
 
240
- # Add new LoRA adapter for GRPO training
241
- from peft import get_peft_model
242
  grpo_lora_config = LoraConfig(
243
- r=8,
244
- lora_alpha=16,
245
  lora_dropout=0.05,
246
  bias="none",
247
  task_type="CAUSAL_LM",
248
- target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
249
  )
250
  model = get_peft_model(model, grpo_lora_config)
251
  model.print_trainable_parameters()
@@ -254,21 +301,29 @@ def main():
254
  # Initialize reward function
255
  reward_fn = QMDRewardFunction()
256
 
257
- # GRPO config
 
 
 
 
 
 
 
258
  config = GRPOConfig(
259
- output_dir="qmd-expansion-grpo",
260
  push_to_hub=True,
261
  hub_model_id=args.output,
262
 
263
- # GRPO specific
264
- num_generations=4, # Generate 4 completions per prompt
265
- max_completion_length=256,
266
 
267
- # Training
268
  num_train_epochs=args.epochs,
269
  per_device_train_batch_size=2,
270
- gradient_accumulation_steps=4,
271
- learning_rate=5e-6, # Lower LR for RL
 
272
 
273
  # Logging
274
  logging_steps=10,
@@ -276,8 +331,8 @@ def main():
276
 
277
  # Monitoring
278
  report_to="trackio",
279
- project="qmd-query-expansion-grpo",
280
- run_name="grpo-format-diversity",
281
  )
282
 
283
  # Create trainer
 
8
  # "trackio",
9
  # "datasets",
10
  # "bitsandbytes",
 
11
  # ]
12
  # ///
13
  """
14
  GRPO (Group Relative Policy Optimization) training for QMD query expansion.
15
 
16
+ Uses the comprehensive scoring system from SCORING.md:
17
+ - Format (30%): Must have lex: and vec: prefixes
18
+ - Diversity (30%): No echoing query, diverse expansions
19
+ - Hyde (20%): Concise, no newlines, no repetition
20
+ - Quality (20%): lex=keywords, vec=natural language
21
 
22
  Usage:
23
  uv run train_grpo.py --sft-model tobil/qmd-query-expansion-0.6B
 
26
  import re
27
  import torch
28
  import trackio
29
+ from collections import Counter
30
  from datasets import load_dataset
31
+ from peft import LoraConfig, PeftModel, get_peft_model
32
  from transformers import AutoModelForCausalLM, AutoTokenizer
33
  from trl import GRPOTrainer, GRPOConfig
34
+
35
+ STOPWORDS = {'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in', 'and', 'or', 'it', 'this', 'that', 'be', 'with', 'as', 'on', 'by'}
36
 
37
  # ============================================================================
38
+ # Scoring Functions (from SCORING.md)
39
  # ============================================================================
40
 
41
  def parse_expansion(text: str) -> dict:
42
+ """Parse expansion into structured format."""
43
+ lines = text.strip().split("\n")
44
+ result = {"lex": [], "vec": [], "hyde": [], "invalid": []}
45
 
46
+ for line in lines:
47
  line = line.strip()
48
+ if not line:
49
+ continue
50
  if line.startswith("lex:"):
51
  result["lex"].append(line[4:].strip())
52
  elif line.startswith("vec:"):
53
  result["vec"].append(line[4:].strip())
54
  elif line.startswith("hyde:"):
55
  result["hyde"].append(line[5:].strip())
56
+ else:
57
+ result["invalid"].append(line)
58
 
59
  return result
60
 
61
 
62
def edit_distance_simple(a: str, b: str) -> int:
    """Simple word-level edit distance: size of the symmetric
    difference of the two (lowercased) word sets."""
    return len(set(a.lower().split()) ^ set(b.lower().split()))
 
 
 
 
 
 
67
 
 
 
 
 
 
 
 
68
 
69
def is_diverse(a: str, b: str, min_distance: int = 2) -> bool:
    """Check if two strings are sufficiently different."""
    left, right = a.lower().strip(), b.lower().strip()
    # Equal strings or substring containment are never diverse.
    if left == right or left in right or right in left:
        return False
    # Word-level symmetric difference as a cheap edit distance.
    return len(set(left.split()) ^ set(right.split())) >= min_distance
77
 
 
78
 
79
def echoes_query(expansion: str, query: str) -> bool:
    """Check if expansion is just echoing the query.

    An expansion "echoes" when it equals the query, or contains it
    verbatim while adding fewer than 10 extra characters.
    """
    exp = expansion.lower().strip()
    q = query.lower().strip()
    # Bug fix: with an empty query (the caller's default when no prompts
    # are supplied), `q in exp` is vacuously True, so every expansion
    # shorter than 10 chars was wrongly flagged as an echo.
    if not q:
        return False
    if exp == q:
        return True
    return q in exp and len(exp) < len(q) + 10
88
 
 
89
 
90
def word_repetition_penalty(text: str) -> int:
    """Count penalty for repeated words (excluding stopwords).

    Words of 4+ occurrences... no: any non-stopword word longer than
    2 chars appearing 3+ times costs 2 points per occurrence past two.
    """
    occurrences = Counter(re.findall(r'\b\w+\b', text.lower()))
    return sum(
        (count - 2) * 2
        for word, count in occurrences.items()
        if count >= 3 and word not in STOPWORDS and len(word) > 2
    )
99
 
100
+
101
def score_expansion(query: str, expansion: str) -> float:
    """
    Score an expansion based on SCORING.md criteria.
    Returns normalized score 0.0-1.0 for RL reward.
    """
    parsed = parse_expansion(expansion)
    lex, vec, hyde = parsed["lex"], parsed["vec"], parsed["hyde"]
    invalid = parsed["invalid"]

    # === FORMAT (0-30) ===
    fmt = (10 if lex else 0) + (10 if vec else 0)
    if not invalid:
        fmt += 10
    else:
        # Each malformed line costs 5 of the 10 cleanliness points.
        fmt += max(0, 10 - len(invalid) * 5)

    # === DIVERSITY (0-30) ===
    div = 0
    if lex and vec:              # both expansion types present
        div += 10
    if len(lex) + len(vec) >= 2:  # at least two expansions overall
        div += 5

    # Pairwise diversity within lex lines (-2 per similar pair).
    lex_pts = 5
    for i, first in enumerate(lex):
        for second in lex[i + 1:]:
            if not is_diverse(first, second, 2):
                lex_pts -= 2
    div += max(0, lex_pts)

    # Pairwise diversity within vec lines (stricter distance of 3).
    vec_pts = 5
    for i, first in enumerate(vec):
        for second in vec[i + 1:]:
            if not is_diverse(first, second, 3):
                vec_pts -= 2
    div += max(0, vec_pts)

    # Don't echo the query back (-3 per echoing expansion).
    echo_pts = 5
    for candidate in lex + vec:
        if echoes_query(candidate, query):
            echo_pts -= 3
    div += max(0, echo_pts)

    # === HYDE (0-20) ===
    hyde_pts = 0
    if hyde:
        doc = hyde[0]
        hyde_pts += 5  # present at all
        # Ideal length window is 50-200 chars.
        if 50 <= len(doc) <= 200:
            hyde_pts += 5
        elif len(doc) < 50:
            hyde_pts += 2
        if "\n" not in doc:
            hyde_pts += 5
        # Up to 5 points, reduced by word repetition.
        hyde_pts += max(0, 5 - word_repetition_penalty(doc))

    # === QUALITY (0-20) ===
    quality = 10  # base
    if lex and vec:
        # lex should be terse keywords, vec longer natural language.
        mean_lex = sum(len(item) for item in lex) / len(lex)
        mean_vec = sum(len(item) for item in vec) / len(vec)
        if mean_lex <= mean_vec:
            quality += 5
    if vec:
        natural = sum(1 for item in vec if " " in item and len(item) > 15)
        quality += 5 if natural == len(vec) else 2

    # === TOTAL ===
    total = fmt + div + hyde_pts + quality
    # Normalize to 0-1; hyde section only counts when present.
    return total / (100 if hyde else 80)
200
 
 
 
 
201
 
202
def extract_query_from_prompt(prompt: str) -> str:
    """Extract the query from the prompt template."""
    # Prompt format: "Expand this search query:\n\n{query}"
    marker = "Expand this search query:"
    if marker not in prompt:
        return prompt.strip()
    return prompt.split(marker)[-1].strip()
 
208
 
209
 
210
class QMDRewardFunction:
    """Reward function using comprehensive SCORING.md criteria."""
    # TRL reads __name__ for logging the reward function's identity.
    __name__ = "qmd_scoring_reward"

    def __call__(self, completions: list[str], prompts: list[str] = None, **kwargs) -> list[float]:
        """Compute rewards for a batch of completions."""
        prompt_list = prompts or []
        scores = []
        for idx, completion in enumerate(completions):
            # Recover the query from the matching prompt when available;
            # otherwise score with an empty query.
            query = (
                extract_query_from_prompt(prompt_list[idx])
                if idx < len(prompt_list)
                else ""
            )
            scores.append(score_expansion(query, completion))
        return scores
229
 
 
239
  help="SFT model to use as starting point")
240
  parser.add_argument("--base-model", default="Qwen/Qwen3-0.6B",
241
  help="Base model (for loading tokenizer)")
242
+ parser.add_argument("--output", default="tobil/qmd-query-expansion-0.6B-grpo-v2",
243
  help="Output model name on Hub")
244
  parser.add_argument("--epochs", type=int, default=1)
245
+ parser.add_argument("--lr", type=float, default=1e-6,
246
+ help="Learning rate (lower for stability)")
247
  parser.add_argument("--dry-run", action="store_true")
248
  args = parser.parse_args()
249
 
 
253
  print(f" Base Model: {args.base_model}")
254
  print(f" Output: {args.output}")
255
  print(f" Epochs: {args.epochs}")
256
+ print(f" LR: {args.lr}")
257
  return
258
 
259
  # Load dataset (just prompts needed for GRPO)
 
265
  return {"prompt": example["messages"][0]["content"]}
266
 
267
  dataset = dataset.map(extract_prompt, remove_columns=dataset.column_names)
268
+ dataset = dataset.shuffle(seed=42).select(range(min(2000, len(dataset))))
269
  print(f"Using {len(dataset)} prompts for GRPO")
270
 
271
  # Load tokenizer
 
282
  device_map="auto",
283
  )
284
  model = PeftModel.from_pretrained(base_model, args.sft_model)
285
+ model = model.merge_and_unload()
286
  print("Model loaded and LoRA merged.")
287
 
288
+ # Add new LoRA adapter for GRPO training (smaller rank for stability)
 
289
  grpo_lora_config = LoraConfig(
290
+ r=4, # Smaller rank for more stable RL
291
+ lora_alpha=8,
292
  lora_dropout=0.05,
293
  bias="none",
294
  task_type="CAUSAL_LM",
295
+ target_modules=["q_proj", "v_proj"], # Fewer modules for stability
296
  )
297
  model = get_peft_model(model, grpo_lora_config)
298
  model.print_trainable_parameters()
 
301
  # Initialize reward function
302
  reward_fn = QMDRewardFunction()
303
 
304
+ # Test reward function
305
+ print("\nTesting reward function...")
306
+ test_good = "lex: auth setup\nlex: authentication config\nvec: how to configure authentication\nhyde: Configure auth by setting AUTH_SECRET."
307
+ test_bad = "auth is important for security"
308
+ print(f" Good output score: {score_expansion('auth', test_good):.2f}")
309
+ print(f" Bad output score: {score_expansion('auth', test_bad):.2f}")
310
+
311
+ # GRPO config with conservative settings
312
  config = GRPOConfig(
313
+ output_dir="qmd-expansion-grpo-v2",
314
  push_to_hub=True,
315
  hub_model_id=args.output,
316
 
317
+ # GRPO specific - conservative
318
+ num_generations=4,
319
+ max_completion_length=200, # Shorter to avoid rambling
320
 
321
+ # Training - very conservative
322
  num_train_epochs=args.epochs,
323
  per_device_train_batch_size=2,
324
+ gradient_accumulation_steps=8,
325
+ learning_rate=args.lr,
326
+ max_grad_norm=0.5, # Clip gradients more aggressively
327
 
328
  # Logging
329
  logging_steps=10,
 
331
 
332
  # Monitoring
333
  report_to="trackio",
334
+ project="qmd-query-expansion-grpo-v2",
335
+ run_name="grpo-scoring-v2",
336
  )
337
 
338
  # Create trainer