tobil commited on
Commit
5775965
·
verified ·
1 Parent(s): 163aa9c

Add 4B GRPO training script

Browse files
Files changed (1) hide show
  1. train_4B_grpo.py +402 -0
train_4B_grpo.py ADDED
@@ -0,0 +1,402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = [
4
+ # "trl>=0.12.0",
5
+ # "peft>=0.7.0",
6
+ # "transformers>=4.45.0",
7
+ # "accelerate>=0.24.0",
8
+ # "huggingface_hub>=0.20.0",
9
+ # "trackio",
10
+ # "datasets",
11
+ # "bitsandbytes",
12
+ # ]
13
+ # ///
14
+ """
15
+ GRPO training for Qwen3-4B query expansion model.
16
+ Trains on top of merged SFT weights with reward function.
17
+ """
18
+
19
+ import os
20
+ import re
21
+ from collections import Counter
22
+
23
+ import torch
24
+ import trackio
25
+ from datasets import load_dataset
26
+ from huggingface_hub import login
27
+ from peft import LoraConfig, PeftModel, get_peft_model
28
+ from transformers import AutoModelForCausalLM, AutoTokenizer
29
+ from trl import GRPOTrainer, GRPOConfig
30
+
31
+ # ==================== REWARD FUNCTION ====================
32
+
33
+ STOPWORDS = {'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in', 'and', 'or', 'it', 'this', 'that', 'be', 'with', 'as', 'on', 'by'}
34
+ KEY_TERM_STOPWORDS = {'what', 'is', 'how', 'to', 'the', 'a', 'an', 'in', 'on', 'for', 'of',
35
+ 'and', 'or', 'with', 'my', 'your', 'do', 'does', 'can', 'i', 'me', 'we',
36
+ 'who', 'where', 'when', 'why', 'which', 'find', 'get', 'show', 'tell'}
37
+
38
+ GENERIC_LEX_PHRASES = {
39
+ 'find information about', 'search for', 'look up', 'get information',
40
+ 'learn about', 'information on', 'details about', 'find out about',
41
+ 'what is', 'how to', 'guide to', 'help with'
42
+ }
43
+
44
+
45
+ def extract_named_entities(query: str) -> set:
46
+ """Extract named entities from query using simple heuristics."""
47
+ entities = set()
48
+ words = query.split()
49
+ prev_was_entity = False
50
+
51
+ for i, word in enumerate(words):
52
+ clean = word.strip('.,!?:;()[]"\'')
53
+ if not clean:
54
+ prev_was_entity = False
55
+ continue
56
+
57
+ is_entity = False
58
+
59
+ if clean.isupper() and len(clean) >= 2:
60
+ entities.add(clean.lower())
61
+ is_entity = True
62
+ elif i > 0 and clean[0].isupper() and clean.lower() not in KEY_TERM_STOPWORDS:
63
+ entities.add(clean.lower())
64
+ is_entity = True
65
+ elif any(c in clean for c in '.+-#@') and len(clean) >= 2:
66
+ entities.add(clean.lower())
67
+ is_entity = True
68
+ elif len(clean) > 1 and any(c.isupper() for c in clean[1:]) and clean[0].isupper():
69
+ entities.add(clean.lower())
70
+ is_entity = True
71
+ elif prev_was_entity and clean.lower() not in KEY_TERM_STOPWORDS:
72
+ entities.add(clean.lower())
73
+ is_entity = True
74
+
75
+ prev_was_entity = is_entity
76
+
77
+ return entities
78
+
79
+
80
+ def get_key_terms(query: str) -> set:
81
+ words = set(query.lower().split())
82
+ return words - KEY_TERM_STOPWORDS
83
+
84
+
85
+ def lex_preserves_key_terms(lex_line: str, query: str) -> bool:
86
+ key_terms = get_key_terms(query)
87
+ if not key_terms:
88
+ return True
89
+ lex_words = set(lex_line.lower().split())
90
+ return bool(key_terms & lex_words)
91
+
92
+
93
+ def lex_preserves_entities(lex_line: str, entities: set) -> bool:
94
+ if not entities:
95
+ return True
96
+ lex_lower = lex_line.lower()
97
+ return any(entity in lex_lower for entity in entities)
98
+
99
+
100
+ def lex_is_generic(lex_line: str) -> bool:
101
+ lex_lower = lex_line.lower().strip()
102
+ for phrase in GENERIC_LEX_PHRASES:
103
+ if phrase in lex_lower or lex_lower.startswith(phrase.split()[0]):
104
+ remaining = lex_lower
105
+ for word in phrase.split():
106
+ remaining = remaining.replace(word, '', 1).strip()
107
+ if len(remaining) < 3:
108
+ return True
109
+ return False
110
+
111
+
112
+ def parse_expansion(text: str) -> dict:
113
+ lines = text.strip().split("\n")
114
+ result = {"lex": [], "vec": [], "hyde": [], "invalid": []}
115
+ for line in lines:
116
+ line = line.strip()
117
+ if not line:
118
+ continue
119
+ if line.startswith("lex:"):
120
+ result["lex"].append(line[4:].strip())
121
+ elif line.startswith("vec:"):
122
+ result["vec"].append(line[4:].strip())
123
+ elif line.startswith("hyde:"):
124
+ result["hyde"].append(line[5:].strip())
125
+ else:
126
+ result["invalid"].append(line)
127
+ return result
128
+
129
+
130
+ def edit_distance_simple(a: str, b: str) -> int:
131
+ words_a = set(a.lower().split())
132
+ words_b = set(b.lower().split())
133
+ return len(words_a ^ words_b)
134
+
135
+
136
+ def is_diverse(a: str, b: str, min_distance: int = 2) -> bool:
137
+ a, b = a.lower().strip(), b.lower().strip()
138
+ if a == b:
139
+ return False
140
+ if a in b or b in a:
141
+ return False
142
+ return edit_distance_simple(a, b) >= min_distance
143
+
144
+
145
+ def echoes_query(expansion: str, query: str) -> bool:
146
+ exp = expansion.lower().strip()
147
+ q = query.lower().strip()
148
+ if exp == q:
149
+ return True
150
+ if q in exp and len(exp) < len(q) + 10:
151
+ return True
152
+ return False
153
+
154
+
155
+ def word_repetition_penalty(text: str) -> int:
156
+ words = re.findall(r'\b\w+\b', text.lower())
157
+ counts = Counter(words)
158
+ penalty = 0
159
+ for word, count in counts.items():
160
+ if count >= 3 and word not in STOPWORDS and len(word) > 2:
161
+ penalty += (count - 2) * 2
162
+ return penalty
163
+
164
+
165
+ def score_expansion(query: str, expansion: str) -> float:
166
+ """Score expansion. Returns 0.0-1.0 for RL reward."""
167
+ text = expansion.strip()
168
+
169
+ # HARD FAIL: Chat template artifacts
170
+ if any(token in text for token in ['<|im_start|>', '<|im_end|>', '<think>', '</think>',
171
+ '\nassistant\n', '\nuser\n', '<|endoftext|>']):
172
+ return 0.0
173
+
174
+ # HARD FAIL: EVERY line must start with lex:, vec:, or hyde:
175
+ for line in text.split("\n"):
176
+ line = line.strip()
177
+ if not line:
178
+ continue
179
+ if not line.startswith(("lex:", "vec:", "hyde:")):
180
+ return 0.0
181
+
182
+ parsed = parse_expansion(expansion)
183
+
184
+ # FORMAT (0-30)
185
+ format_score = 0
186
+ if parsed["lex"]:
187
+ format_score += 10
188
+ if parsed["vec"]:
189
+ format_score += 10
190
+ format_score += 10
191
+
192
+ # DIVERSITY (0-30)
193
+ diversity_score = 0
194
+ types_present = sum(1 for t in ["lex", "vec"] if parsed[t])
195
+ if types_present >= 2:
196
+ diversity_score += 10
197
+ total_expansions = len(parsed["lex"]) + len(parsed["vec"])
198
+ if total_expansions >= 2:
199
+ diversity_score += 5
200
+
201
+ lex_score = 5
202
+ for i, a in enumerate(parsed["lex"]):
203
+ for b in parsed["lex"][i+1:]:
204
+ if not is_diverse(a, b, 2):
205
+ lex_score -= 2
206
+ diversity_score += max(0, lex_score)
207
+
208
+ vec_score = 5
209
+ for i, a in enumerate(parsed["vec"]):
210
+ for b in parsed["vec"][i+1:]:
211
+ if not is_diverse(a, b, 3):
212
+ vec_score -= 2
213
+ diversity_score += max(0, vec_score)
214
+
215
+ echo_score = 5
216
+ for exp in parsed["lex"] + parsed["vec"]:
217
+ if echoes_query(exp, query):
218
+ echo_score -= 3
219
+ diversity_score += max(0, echo_score)
220
+
221
+ # HYDE (0-20)
222
+ hyde_score = 0
223
+ if parsed["hyde"]:
224
+ hyde_text = parsed["hyde"][0]
225
+ hyde_score += 5
226
+ hyde_len = len(hyde_text)
227
+ if 50 <= hyde_len <= 200:
228
+ hyde_score += 5
229
+ elif hyde_len < 50:
230
+ hyde_score += 2
231
+ if "\n" not in hyde_text:
232
+ hyde_score += 5
233
+ rep_penalty = word_repetition_penalty(hyde_text)
234
+ hyde_score += max(0, 5 - rep_penalty)
235
+
236
+ # QUALITY (0-20)
237
+ quality_score = 5
238
+ if parsed["lex"] and parsed["vec"]:
239
+ avg_lex = sum(len(l) for l in parsed["lex"]) / len(parsed["lex"])
240
+ avg_vec = sum(len(v) for v in parsed["vec"]) / len(parsed["vec"])
241
+ if avg_lex <= avg_vec:
242
+ quality_score += 5
243
+ if parsed["vec"]:
244
+ natural = sum(1 for v in parsed["vec"] if " " in v and len(v) > 15)
245
+ if natural == len(parsed["vec"]):
246
+ quality_score += 5
247
+ else:
248
+ quality_score += 2
249
+ if parsed["lex"]:
250
+ lex_with_terms = sum(1 for l in parsed["lex"] if lex_preserves_key_terms(l, query))
251
+ if lex_with_terms == len(parsed["lex"]):
252
+ quality_score += 5
253
+ elif lex_with_terms > 0:
254
+ quality_score += 2
255
+
256
+ # NAMED ENTITY PRESERVATION
257
+ entity_score = 0
258
+ entities = extract_named_entities(query)
259
+ if entities and parsed["lex"]:
260
+ lex_with_entities = sum(1 for l in parsed["lex"] if lex_preserves_entities(l, entities))
261
+ if lex_with_entities == len(parsed["lex"]):
262
+ entity_score += 15
263
+ elif lex_with_entities > 0:
264
+ entity_score += 5
265
+ else:
266
+ entity_score -= 30
267
+
268
+ generic_count = sum(1 for l in parsed["lex"] if lex_is_generic(l))
269
+ entity_score -= generic_count * 15
270
+
271
+ if parsed["vec"]:
272
+ vec_with_entities = sum(1 for v in parsed["vec"] if lex_preserves_entities(v, entities))
273
+ if vec_with_entities > 0:
274
+ entity_score += 5
275
+ elif not entities:
276
+ entity_score = 10
277
+
278
+ total = format_score + diversity_score + hyde_score + quality_score + entity_score
279
+ max_possible = 120 if parsed["hyde"] else 100
280
+ return max(0.0, min(1.0, total / max_possible))
281
+
282
+
283
+ def extract_query_from_prompt(prompt: str) -> str:
284
+ if "Expand this search query:" in prompt:
285
+ return prompt.split("Expand this search query:")[-1].strip()
286
+ return prompt.strip()
287
+
288
+
289
+ class QMDRewardFunction:
290
+ __name__ = "qmd_scoring_reward"
291
+
292
+ def __call__(self, completions: list[str], prompts: list[str] = None, **kwargs) -> list[float]:
293
+ rewards = []
294
+ for i, completion in enumerate(completions):
295
+ query = ""
296
+ if prompts and i < len(prompts):
297
+ query = extract_query_from_prompt(prompts[i])
298
+ score = score_expansion(query, completion)
299
+ rewards.append(score)
300
+ return rewards
301
+
302
+
303
+ # ==================== MAIN ====================
304
+
305
+ def main():
306
+ # Config
307
+ SFT_MODEL = "tobil/qmd-query-expansion-4B-sft"
308
+ BASE_MODEL = "Qwen/Qwen3-4B"
309
+ OUTPUT_MODEL = "tobil/qmd-query-expansion-4B-grpo"
310
+ DATASET = "tobil/qmd-query-expansion-train-v2"
311
+
312
+ # Login
313
+ hf_token = os.environ.get("HF_TOKEN")
314
+ if hf_token:
315
+ print("Logging in to HuggingFace Hub...")
316
+ login(token=hf_token)
317
+
318
+ # Load dataset
319
+ print("Loading dataset...")
320
+ dataset = load_dataset(DATASET, split="train")
321
+
322
+ def extract_prompt(example):
323
+ return {"prompt": example["messages"][0]["content"]}
324
+
325
+ dataset = dataset.map(extract_prompt, remove_columns=dataset.column_names)
326
+ dataset = dataset.shuffle(seed=42).select(range(min(1000, len(dataset))))
327
+ print(f"Using {len(dataset)} prompts for GRPO")
328
+
329
+ # Load tokenizer and model
330
+ print(f"Loading tokenizer from {BASE_MODEL}...")
331
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
332
+ if tokenizer.pad_token is None:
333
+ tokenizer.pad_token = tokenizer.eos_token
334
+
335
+ print(f"Loading SFT model from {SFT_MODEL}...")
336
+ base_model = AutoModelForCausalLM.from_pretrained(
337
+ BASE_MODEL,
338
+ torch_dtype=torch.bfloat16,
339
+ device_map="auto",
340
+ )
341
+ model = PeftModel.from_pretrained(base_model, SFT_MODEL)
342
+ model = model.merge_and_unload()
343
+ print("Model loaded and LoRA merged.")
344
+
345
+ # Add LoRA for GRPO
346
+ grpo_lora_config = LoraConfig(
347
+ r=4,
348
+ lora_alpha=8,
349
+ lora_dropout=0.05,
350
+ bias="none",
351
+ task_type="CAUSAL_LM",
352
+ target_modules=["q_proj", "v_proj"],
353
+ )
354
+ model = get_peft_model(model, grpo_lora_config)
355
+ model.print_trainable_parameters()
356
+
357
+ # GRPO config
358
+ config = GRPOConfig(
359
+ output_dir="qmd-query-expansion-4B-grpo",
360
+ push_to_hub=True,
361
+ hub_model_id=OUTPUT_MODEL,
362
+
363
+ num_generations=4,
364
+ max_completion_length=200,
365
+
366
+ num_train_epochs=1,
367
+ per_device_train_batch_size=1, # Smaller for 4B model
368
+ gradient_accumulation_steps=16, # Compensate with more accumulation
369
+ learning_rate=5e-7,
370
+ max_grad_norm=0.5,
371
+ max_steps=200,
372
+
373
+ logging_steps=10,
374
+ save_strategy="epoch",
375
+
376
+ report_to="trackio",
377
+ project="qmd-query-expansion",
378
+ run_name="qwen3-4b-grpo",
379
+ )
380
+
381
+ # Train
382
+ print("Initializing GRPO trainer...")
383
+ trainer = GRPOTrainer(
384
+ model=model,
385
+ processing_class=tokenizer,
386
+ args=config,
387
+ train_dataset=dataset,
388
+ reward_funcs=[QMDRewardFunction()],
389
+ )
390
+
391
+ print("Starting GRPO training...")
392
+ trainer.train()
393
+
394
+ print("Pushing to Hub...")
395
+ trainer.push_to_hub()
396
+
397
+ trackio.finish()
398
+ print(f"Complete! Model at: https://huggingface.co/{OUTPUT_MODEL}")
399
+
400
+
401
+ if __name__ == "__main__":
402
+ main()