tobil committed on
Commit
b186bb2
·
verified ·
1 Parent(s): c7967b0

Upload train_grpo.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_grpo.py +11 -0
train_grpo.py CHANGED
@@ -5,6 +5,7 @@
5
  # "peft>=0.7.0",
6
  # "transformers>=4.45.0",
7
  # "accelerate>=0.24.0",
 
8
  # "trackio",
9
  # "datasets",
10
  # "bitsandbytes",
@@ -23,11 +24,13 @@ Usage:
23
  uv run train_grpo.py --sft-model tobil/qmd-query-expansion-0.6B
24
  """
25
 
 
26
  import re
27
  import torch
28
  import trackio
29
  from collections import Counter
30
  from datasets import load_dataset
 
31
  from peft import LoraConfig, PeftModel, get_peft_model
32
  from transformers import AutoModelForCausalLM, AutoTokenizer
33
  from trl import GRPOTrainer, GRPOConfig
@@ -256,6 +259,14 @@ def main():
256
  print(f" LR: {args.lr}")
257
  return
258
 
 
 
 
 
 
 
 
 
259
  # Load dataset (just prompts needed for GRPO)
260
  print("Loading dataset...")
261
  dataset = load_dataset("tobil/qmd-query-expansion-train", split="train")
 
5
  # "peft>=0.7.0",
6
  # "transformers>=4.45.0",
7
  # "accelerate>=0.24.0",
8
+ # "huggingface_hub>=0.20.0",
9
  # "trackio",
10
  # "datasets",
11
  # "bitsandbytes",
 
24
  uv run train_grpo.py --sft-model tobil/qmd-query-expansion-0.6B
25
  """
26
 
27
+ import os
28
  import re
29
  import torch
30
  import trackio
31
  from collections import Counter
32
  from datasets import load_dataset
33
+ from huggingface_hub import login
34
  from peft import LoraConfig, PeftModel, get_peft_model
35
  from transformers import AutoModelForCausalLM, AutoTokenizer
36
  from trl import GRPOTrainer, GRPOConfig
 
259
  print(f" LR: {args.lr}")
260
  return
261
 
262
+ # Login to HuggingFace Hub
263
+ hf_token = os.environ.get("HF_TOKEN")
264
+ if hf_token:
265
+ print("Logging in to HuggingFace Hub...")
266
+ login(token=hf_token)
267
+ else:
268
+ print("Warning: HF_TOKEN not set, will try cached login")
269
+
270
  # Load dataset (just prompts needed for GRPO)
271
  print("Loading dataset...")
272
  dataset = load_dataset("tobil/qmd-query-expansion-train", split="train")