tobil committed on
Commit
bf6bf1b
·
verified ·
1 Parent(s): 36fb469

Upload train_1.7B_grpo.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_1.7B_grpo.py +16 -8
train_1.7B_grpo.py CHANGED
@@ -290,7 +290,11 @@ def score_expansion(query: str, expansion: str) -> float:
290
 
291
  def extract_query_from_prompt(prompt: str) -> str:
292
  if "Expand this search query:" in prompt:
293
- return prompt.split("Expand this search query:")[-1].strip()
 
 
 
 
294
  return prompt.strip()
295
 
296
 
@@ -323,23 +327,27 @@ def main():
323
  print("Logging in to HuggingFace Hub...")
324
  login(token=hf_token)
325
 
 
 
 
 
 
 
326
  # Load dataset
327
  print("Loading dataset...")
328
  dataset = load_dataset(DATASET, split="train")
329
 
330
  def extract_prompt(example):
331
- return {"prompt": example["messages"][0]["content"]}
 
 
 
 
332
 
333
  dataset = dataset.map(extract_prompt, remove_columns=dataset.column_names)
334
  dataset = dataset.shuffle(seed=42).select(range(min(1000, len(dataset))))
335
  print(f"Using {len(dataset)} prompts for GRPO")
336
 
337
- # Load tokenizer and model
338
- print(f"Loading tokenizer from {BASE_MODEL}...")
339
- tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
340
- if tokenizer.pad_token is None:
341
- tokenizer.pad_token = tokenizer.eos_token
342
-
343
  print(f"Loading SFT model from {SFT_MODEL}...")
344
  base_model = AutoModelForCausalLM.from_pretrained(
345
  BASE_MODEL,
 
290
 
291
def extract_query_from_prompt(prompt: str) -> str:
    """Pull the raw search query out of a GRPO prompt string.

    If *prompt* contains the "Expand this search query:" instruction,
    return only the text following its last occurrence, with any
    chat-template end-of-turn artifact ("<|im_end|>" and everything
    after it) removed.  Otherwise return the stripped prompt unchanged.
    """
    marker = "Expand this search query:"
    if marker not in prompt:
        return prompt.strip()
    # Everything after the last occurrence of the instruction marker.
    _, _, tail = prompt.rpartition(marker)
    query = tail.strip()
    # Strip chat-template artifacts if the template left them in place.
    end_token = "<|im_end|>"
    if end_token in query:
        query = query.partition(end_token)[0].strip()
    return query
299
 
300
 
 
327
  print("Logging in to HuggingFace Hub...")
328
  login(token=hf_token)
329
 
330
+ # Load tokenizer first (needed for chat template)
331
+ print(f"Loading tokenizer from {BASE_MODEL}...")
332
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
333
+ if tokenizer.pad_token is None:
334
+ tokenizer.pad_token = tokenizer.eos_token
335
+
336
  # Load dataset
337
  print("Loading dataset...")
338
  dataset = load_dataset(DATASET, split="train")
339
 
340
def extract_prompt(example):
    """Map one dataset row to a chat-templated prompt dict.

    Wraps the first message's content in a single-turn user chat and
    renders it through the tokenizer's chat template (with a generation
    prompt appended) so the model sees the same formatting used during
    SFT training.  `tokenizer` is taken from the enclosing scope.
    """
    user_turn = {"role": "user", "content": example["messages"][0]["content"]}
    rendered = tokenizer.apply_chat_template(
        [user_turn],
        tokenize=False,
        add_generation_prompt=True,
    )
    return {"prompt": rendered}
346
 
347
  dataset = dataset.map(extract_prompt, remove_columns=dataset.column_names)
348
  dataset = dataset.shuffle(seed=42).select(range(min(1000, len(dataset))))
349
  print(f"Using {len(dataset)} prompts for GRPO")
350
 
 
 
 
 
 
 
351
  print(f"Loading SFT model from {SFT_MODEL}...")
352
  base_model = AutoModelForCausalLM.from_pretrained(
353
  BASE_MODEL,