# testing/train_sharegpt_polar.py
# Author: Delta-Vector
# Uploaded via huggingface_hub (commit 86dd177, verified)
"""GRPO training entrypoint for ShareGPT POLAR environment."""
from __future__ import annotations
import os
import verifiers as vf
from xtuner.utils import RewardModelClient
# --- Run identity, model, and endpoints -----------------------------------
RUN_NAME = "sharegpt-polar"
MODEL_NAME = "NewEden/Snwy-SFT-GRPO-base"
# JSONL dataset consumed by the sharegpt-polar environment.
DATA_PATH = "/home/Ubuntu/Mango/verifiers/new.jsonl"
# Tunnel hostname fronting the lmdeploy server that hosts the reward model.
SERVER_ADDRESS = "greene-cannon-republic-expect.trycloudflare.com"
REWARD_MODEL = "internlm/POLAR-7B"
# Training hyperparameters
PER_DEVICE_TRAIN_BATCH_SIZE = 2
NUM_GENERATIONS = 8  # completions sampled per prompt (GRPO group size)
GRADIENT_ACCUMULATION_STEPS = 2
LEARNING_RATE = 1e-6
BETA = 0.1  # GRPO beta (KL regularization strength)
MAX_STEPS = 1000
MAX_GRAD_NORM = 1.0
NUM_ITERATIONS = 1
# Rollout sampling settings.
MAX_TOKENS = 512
TEMPERATURE = 1.0
TOP_P = 1.0
# Checkpointing / logging cadence.
SAVE_EVERY_STEPS = 50
LOGGING_STEPS = 1
REPORT_TO = ["wandb"]
LOG_COMPLETIONS = True
LOG_ON_EACH_NODE = False
ASYNC_GENERATION_TIMEOUT = 60000  # NOTE(review): units (ms vs s) not evident here — confirm against verifiers docs
MAX_CONCURRENT = 1024  # cap on concurrent generation requests
WANDB_PROJECT = "14B-GRPO"
WANDB_NAME = RUN_NAME
# setdefault: respect values already exported in the shell environment.
if WANDB_PROJECT:
    os.environ.setdefault("WANDB_PROJECT", WANDB_PROJECT)
if WANDB_NAME:
    os.environ.setdefault("WANDB_NAME", WANDB_NAME)
def _check_reward_server() -> None:
    """Smoke-test the POLAR reward server before any expensive setup.

    Sends two fixed prompt/reference/output triples (one matching the
    reference, one not) through the lmdeploy endpoint and prints the
    returned rewards so a human can eyeball them.
    """
    probe_samples = [
        {
            "prompt": "What is the capital of China?",
            "reference": "Beijing.",
            "output": "Beijing.",
        },
        {
            "prompt": "What is the capital of China?",
            "reference": "Beijing.",
            "output": "Shanghai.",
        },
    ]
    rm_client = RewardModelClient(
        REWARD_MODEL,
        server_type="lmdeploy",
        server_address=SERVER_ADDRESS,
    )
    scores = rm_client.lmdeploy_request_reward(rm_client.encode(probe_samples))
    print("[sanity] lmdeploy rewards:", scores)
# Fail fast: probe the reward server before loading any model weights.
_check_reward_server()
# NOTE(review): everything below runs at import time — only trainer.train()
# is gated by the __main__ guard, so importing this module loads the model
# and contacts the reward server as a side effect.
vf_env = vf.load_environment(
    env_id="sharegpt-polar",
    data_path=DATA_PATH,
    server_address=SERVER_ADDRESS,
)
# Policy model plus its matching tokenizer.
model, tokenizer = vf.get_model_and_tokenizer(MODEL_NAME)
# Start from library defaults, then apply the overrides defined above.
training_args = vf.grpo_defaults(run_name=RUN_NAME)
training_args.per_device_train_batch_size = PER_DEVICE_TRAIN_BATCH_SIZE
training_args.num_generations = NUM_GENERATIONS
training_args.gradient_accumulation_steps = GRADIENT_ACCUMULATION_STEPS
training_args.learning_rate = LEARNING_RATE
training_args.beta = BETA
training_args.max_steps = MAX_STEPS
training_args.max_grad_norm = MAX_GRAD_NORM
training_args.num_iterations = NUM_ITERATIONS
training_args.max_tokens = MAX_TOKENS
training_args.temperature = TEMPERATURE
training_args.top_p = TOP_P
training_args.save_strategy = "steps"
training_args.save_steps = SAVE_EVERY_STEPS
training_args.logging_steps = LOGGING_STEPS
training_args.report_to = REPORT_TO
training_args.log_completions = LOG_COMPLETIONS
training_args.log_on_each_node = LOG_ON_EACH_NODE
training_args.async_generation_timeout = ASYNC_GENERATION_TIMEOUT
training_args.max_concurrent = MAX_CONCURRENT
# LoRA fine-tuning (rank 128, alpha 64) rather than full-parameter updates.
trainer = vf.GRPOTrainer(
    env=vf_env,
    model=model,
    processing_class=tokenizer,
    args=training_args,
    peft_config=vf.lora_defaults(r=128, alpha=64),
)
if __name__ == "__main__":
    trainer.train()