Upload train_grpo.py with huggingface_hub
Browse files- train_grpo.py +11 -0
train_grpo.py
CHANGED
|
@@ -5,6 +5,7 @@
|
|
| 5 |
# "peft>=0.7.0",
|
| 6 |
# "transformers>=4.45.0",
|
| 7 |
# "accelerate>=0.24.0",
|
|
|
|
| 8 |
# "trackio",
|
| 9 |
# "datasets",
|
| 10 |
# "bitsandbytes",
|
|
@@ -23,11 +24,13 @@ Usage:
|
|
| 23 |
uv run train_grpo.py --sft-model tobil/qmd-query-expansion-0.6B
|
| 24 |
"""
|
| 25 |
|
|
|
|
| 26 |
import re
|
| 27 |
import torch
|
| 28 |
import trackio
|
| 29 |
from collections import Counter
|
| 30 |
from datasets import load_dataset
|
|
|
|
| 31 |
from peft import LoraConfig, PeftModel, get_peft_model
|
| 32 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 33 |
from trl import GRPOTrainer, GRPOConfig
|
|
@@ -256,6 +259,14 @@ def main():
|
|
| 256 |
print(f" LR: {args.lr}")
|
| 257 |
return
|
| 258 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 259 |
# Load dataset (just prompts needed for GRPO)
|
| 260 |
print("Loading dataset...")
|
| 261 |
dataset = load_dataset("tobil/qmd-query-expansion-train", split="train")
|
|
|
|
| 5 |
# "peft>=0.7.0",
|
| 6 |
# "transformers>=4.45.0",
|
| 7 |
# "accelerate>=0.24.0",
|
| 8 |
+
# "huggingface_hub>=0.20.0",
|
| 9 |
# "trackio",
|
| 10 |
# "datasets",
|
| 11 |
# "bitsandbytes",
|
|
|
|
| 24 |
uv run train_grpo.py --sft-model tobil/qmd-query-expansion-0.6B
|
| 25 |
"""
|
| 26 |
|
| 27 |
+
import os
|
| 28 |
import re
|
| 29 |
import torch
|
| 30 |
import trackio
|
| 31 |
from collections import Counter
|
| 32 |
from datasets import load_dataset
|
| 33 |
+
from huggingface_hub import login
|
| 34 |
from peft import LoraConfig, PeftModel, get_peft_model
|
| 35 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 36 |
from trl import GRPOTrainer, GRPOConfig
|
|
|
|
| 259 |
print(f" LR: {args.lr}")
|
| 260 |
return
|
| 261 |
|
| 262 |
+
# Login to HuggingFace Hub
|
| 263 |
+
hf_token = os.environ.get("HF_TOKEN")
|
| 264 |
+
if hf_token:
|
| 265 |
+
print("Logging in to HuggingFace Hub...")
|
| 266 |
+
login(token=hf_token)
|
| 267 |
+
else:
|
| 268 |
+
print("Warning: HF_TOKEN not set, will try cached login")
|
| 269 |
+
|
| 270 |
# Load dataset (just prompts needed for GRPO)
|
| 271 |
print("Loading dataset...")
|
| 272 |
dataset = load_dataset("tobil/qmd-query-expansion-train", split="train")
|