Upload train_1.7B_grpo.py with huggingface_hub
train_1.7B_grpo.py (+16, -8)
@@ -290,7 +290,11 @@ def score_expansion(query: str, expansion: str) -> float:
 
 def extract_query_from_prompt(prompt: str) -> str:
     if "Expand this search query:" in prompt:
-        return prompt.split("Expand this search query:")[-1].strip()
+        query = prompt.split("Expand this search query:")[-1].strip()
+        # Remove chat template artifacts if present
+        if "<|im_end|>" in query:
+            query = query.split("<|im_end|>")[0].strip()
+        return query
     return prompt.strip()
 
 
@@ -323,23 +327,27 @@ def main():
     print("Logging in to HuggingFace Hub...")
     login(token=hf_token)
 
+    # Load tokenizer first (needed for chat template)
+    print(f"Loading tokenizer from {BASE_MODEL}...")
+    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
     # Load dataset
     print("Loading dataset...")
     dataset = load_dataset(DATASET, split="train")
 
     def extract_prompt(example):
-        return {"prompt": example["messages"][0]["content"]}
+        # Apply chat template so model sees the same format as SFT training
+        content = example["messages"][0]["content"]
+        messages = [{"role": "user", "content": content}]
+        formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        return {"prompt": formatted}
 
     dataset = dataset.map(extract_prompt, remove_columns=dataset.column_names)
     dataset = dataset.shuffle(seed=42).select(range(min(1000, len(dataset))))
     print(f"Using {len(dataset)} prompts for GRPO")
 
-    # Load tokenizer and model
-    print(f"Loading tokenizer from {BASE_MODEL}...")
-    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
-    if tokenizer.pad_token is None:
-        tokenizer.pad_token = tokenizer.eos_token
-
     print(f"Loading SFT model from {SFT_MODEL}...")
     base_model = AutoModelForCausalLM.from_pretrained(
         BASE_MODEL,
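For context on why the tokenizer load moves ahead of the dataset map: `extract_prompt` now calls `tokenizer.apply_chat_template`, so GRPO prompts carry the same chat markup the SFT model was trained on. A minimal sketch of what that formatting produces, assuming a ChatML-style template (the `<|im_end|>` check in the diff points at one; the checkpoint name and query below are illustrative, not taken from this script):

from transformers import AutoTokenizer

# Illustrative checkpoint; any ChatML-style tokenizer behaves similarly.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")

messages = [{"role": "user", "content": "Expand this search query: rust async"}]
formatted = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(formatted)
# Approximately (a default system turn may also be prepended):
# <|im_start|>user
# Expand this search query: rust async<|im_end|>
# <|im_start|>assistant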
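That templated prompt is also why `extract_query_from_prompt` gains the `<|im_end|>` guard: the text after "Expand this search query:" now ends with closing ChatML tokens rather than the raw query. A self-contained round-trip check of the patched function (the prompt string is a hand-written ChatML example, not output captured from the script):

def extract_query_from_prompt(prompt: str) -> str:
    """Recover the raw query from a (possibly chat-templated) prompt."""
    if "Expand this search query:" in prompt:
        query = prompt.split("Expand this search query:")[-1].strip()
        # Remove chat template artifacts if present
        if "<|im_end|>" in query:
            query = query.split("<|im_end|>")[0].strip()
        return query
    return prompt.strip()

prompt = (
    "<|im_start|>user\n"
    "Expand this search query: rust async<|im_end|>\n"
    "<|im_start|>assistant\n"
)
assert extract_query_from_prompt(prompt) == "rust async"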