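"""GRPO trainer for the CommitGuard environment, packaged for HF Hub Jobs.

Fetches prompts from a CommitGuard environment server, fine-tunes a 4-bit
LoRA adapter with Unsloth and TRL's GRPOTrainer, and scores each completion
by POSTing it back to the environment's /step endpoint.
"""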
import argparse
import os
import sys
import requests
from pathlib import Path

# Unsloth should be imported before trl/transformers so its patches apply.
from unsloth import FastLanguageModel, PatchFastRL

from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer
from huggingface_hub import login
# Add project root to sys.path
root_dir = Path(__file__).resolve().parent.parent
if str(root_dir) not in sys.path:
    sys.path.append(str(root_dir))
from commitguard_env.grpo_prompt import SYSTEM_PROMPT, get_agent_prompt
# Patch TRL for Unsloth speedups
PatchFastRL("GRPO", FastLanguageModel)

def get_reward_from_env_base(env_url):
    """Build a TRL reward function that scores completions via the environment.

    Each completion is POSTed to the environment's /step endpoint as an
    action, and the scalar reward from the response is used directly.
    Non-200 responses are penalized with -0.5, connection failures with -1.0.
    """
    def reward_fn(prompts, completions, **kwargs) -> list[float]:
        rewards = []
        for completion in completions:
            try:
                payload = {"action": completion}
                r = requests.post(f"{env_url}/step", json=payload, timeout=15)
                if r.status_code == 200:
                    rewards.append(float(r.json().get("reward", 0.0)))
                else:
                    rewards.append(-0.5)
            except Exception:
                rewards.append(-1.0)
        return rewards

    return reward_fn

def main():
    parser = argparse.ArgumentParser(description="CommitGuard GRPO Trainer for Hugging Face Hub Jobs")
    parser.add_argument("--model_name", type=str, default=os.getenv("MODEL_NAME", "meta-llama/Llama-3.2-3B-Instruct"))
    parser.add_argument("--output_dir", type=str, default="outputs/commitguard-hf")
    parser.add_argument("--steps", type=int, default=int(os.getenv("STEPS", "500")))
    parser.add_argument("--env_url", type=str, default=os.getenv("ENV_URL", "http://localhost:8000"))
    parser.add_argument("--hf_repo", type=str, default=os.getenv("HF_REPO"))
    parser.add_argument("--wandb", type=str, default=os.getenv("WANDB_PROJECT", "commitguard-rlvr"))
    parser.add_argument("--num_generations", type=int, default=int(os.getenv("NUM_GENERATIONS", "4")))
    args = parser.parse_args()
    # 0. Auth
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        login(token=hf_token)
    if args.wandb:
        os.environ["WANDB_PROJECT"] = args.wandb
    if os.getenv("WANDB_API_KEY"):
        import wandb
        wandb.login(key=os.getenv("WANDB_API_KEY"))
print(f"--- Training Config ---")
print(f"Model: {args.model_name}")
print(f"Steps: {args.steps}")
print(f"Env URL: {args.env_url}")
print(f"HF Repo: {args.hf_repo}")
print(f"-----------------------")
    # 1. Load Model and Tokenizer with Unsloth
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.model_name,
        max_seq_length=2048,
        load_in_4bit=True,
        fast_inference=False,
    )
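    # Note: fast_inference=False keeps generation on the standard HF path;
    # Unsloth can route GRPO sampling through vLLM when it is True (which
    # requires vLLM to be installed).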
    model = FastLanguageModel.get_peft_model(
        model,
        r=8,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        lora_alpha=16,
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=3407,
    )
    # Workaround: some Transformers/TRL code paths expect a `warnings_issued`
    # dict on the model, which the Unsloth-patched model may lack.
    if not hasattr(model, "warnings_issued"):
        model.warnings_issued = {}
    # 2. Prepare Dataset from Environment
    print(f"Fetching up to {min(args.steps, 1000)} samples from environment...")
    train_samples = []
    # One /reset call per sample; bulk fetching might be faster, but this
    # matches the environment's current per-episode API.
    for _ in range(min(args.steps, 1000)):
        try:
            r = requests.post(f"{args.env_url}/reset", timeout=10)
            if r.status_code == 200:
                obs = r.json()["observation"]
                prompt = get_agent_prompt(obs["diff"], obs["available_files"], obs["step_idx"])
                train_samples.append({"prompt": prompt, "system": SYSTEM_PROMPT})
        except Exception as e:
            print(f"Warning: failed to fetch sample: {e}")
            break
    if not train_samples:
        print("Error: no training samples fetched. Check ENV_URL.")
        sys.exit(1)
    dataset = Dataset.from_list(train_samples)
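    # Fetching stops early on the first error, so the dataset may be smaller
    # than requested; log the actual count.
    print(f"Prepared {len(dataset)} training prompts.")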
    # 3. Configure GRPO
    training_args = GRPOConfig(
        output_dir=args.output_dir,
        num_generations=args.num_generations,
        max_completion_length=512,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        learning_rate=5e-6,
        logging_steps=1,
        save_steps=100,
        max_steps=args.steps,
        report_to="wandb" if os.getenv("WANDB_API_KEY") else "none",
        bf16=True,
        push_to_hub=bool(args.hf_repo),
        hub_model_id=args.hf_repo,
        hub_strategy="end",
    )
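    # Note (assumption): some TRL versions require the global train batch size
    # (per_device_train_batch_size x number of processes) to be divisible by
    # num_generations; if GRPOTrainer raises a divisibility error, raise
    # per_device_train_batch_size or lower --num_generations.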
    # 4. Initialize Trainer
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[get_reward_from_env_base(args.env_url)],
        args=training_args,
        train_dataset=dataset,
    )
    # 5. Launch Training
    print("Starting GRPO Training...")
    trainer.train()
    # 6. Final Push
    if args.hf_repo:
        print(f"Pushing final adapter to {args.hf_repo}...")
        model.push_to_hub(args.hf_repo, token=hf_token)
        tokenizer.push_to_hub(args.hf_repo, token=hf_token)
    else:
        final_path = os.path.join(args.output_dir, "final")
        model.save_pretrained_merged(final_path, tokenizer, save_method="lora")
        print(f"Saved locally to {final_path}")

if __name__ == "__main__":
    main()