# commitguard-env/scripts/train_on_hf.py
# Author: Nitishkumar-ai
# Deployment Build (Final): Professional Structure + Blog (rev 95cbc5b)
import argparse
import os
import sys
import requests
from pathlib import Path
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer
from unsloth import FastLanguageModel, PatchFastRL
from huggingface_hub import login
# Add the project root (parent of this script's directory) to sys.path so the
# local `commitguard_env` package resolves when this file is run as a script.
root_dir = Path(__file__).resolve().parent.parent
if str(root_dir) not in sys.path:
    sys.path.append(str(root_dir))

from commitguard_env.grpo_prompt import SYSTEM_PROMPT, get_agent_prompt

# Patch TRL for Unsloth speedups — must run before GRPOTrainer is constructed.
PatchFastRL("GRPO", FastLanguageModel)
def get_reward_from_env_base(env_url):
    """Build a GRPO reward function that scores completions via a remote env.

    Args:
        env_url: Base URL of the CommitGuard environment server; each
            completion is POSTed to ``{env_url}/step`` as ``{"action": ...}``.

    Returns:
        A callable matching TRL's reward-function signature
        ``(prompts, completions, **kwargs) -> list[float]``.
    """

    def reward_fn(prompts, completions, **kwargs) -> list[float]:
        """Score each completion: the env's reward on HTTP 200, -0.5 on any
        other status, -1.0 when the request or response parsing fails."""
        rewards = []
        for completion in completions:
            try:
                payload = {"action": completion}
                r = requests.post(f"{env_url}/step", json=payload, timeout=15)
                if r.status_code == 200:
                    rewards.append(float(r.json().get("reward", 0.0)))
                else:
                    # Env reachable but rejected the step: mild penalty.
                    rewards.append(-0.5)
            except (requests.RequestException, ValueError, TypeError):
                # Network failure, timeout, or malformed JSON/reward value:
                # strong penalty. Narrowed from bare `except Exception` so
                # genuine programming errors surface instead of becoming -1.0.
                # (requests' JSONDecodeError subclasses ValueError.)
                rewards.append(-1.0)
        return rewards

    return reward_fn
def main() -> None:
    """Entry point: GRPO fine-tuning against a remote CommitGuard environment.

    Every option has an environment-variable fallback so the script can run
    unattended as a Hugging Face Hub Job. Flow: auth -> load 4-bit model with
    LoRA adapters -> fetch prompts from the env server -> GRPO train with an
    HTTP-based reward -> push the adapter to the Hub or save locally.
    """
    parser = argparse.ArgumentParser(description="CommitGuard GRPO Trainer for Hugging Face Hub Jobs")
    parser.add_argument("--model_name", type=str, default=os.getenv("MODEL_NAME", "meta-llama/Llama-3.2-3B-Instruct"))
    parser.add_argument("--output_dir", type=str, default="outputs/commitguard-hf")
    parser.add_argument("--steps", type=int, default=int(os.getenv("STEPS", "500")))
    parser.add_argument("--env_url", type=str, default=os.getenv("ENV_URL", "http://localhost:8000"))
    parser.add_argument("--hf_repo", type=str, default=os.getenv("HF_REPO"))
    parser.add_argument("--wandb", type=str, default=os.getenv("WANDB_PROJECT", "commitguard-rlvr"))
    parser.add_argument("--num_generations", type=int, default=int(os.getenv("NUM_GENERATIONS", "4")))
    args = parser.parse_args()

    # 0. Auth: Hub login (gated models / pushing) and optional W&B login.
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        login(token=hf_token)
    if args.wandb:
        os.environ["WANDB_PROJECT"] = args.wandb
    if os.getenv("WANDB_API_KEY"):
        # Imported lazily so wandb is only required when a key is provided.
        import wandb
        wandb.login(key=os.getenv("WANDB_API_KEY"))

    print(f"--- Training Config ---")
    print(f"Model: {args.model_name}")
    print(f"Steps: {args.steps}")
    print(f"Env URL: {args.env_url}")
    print(f"HF Repo: {args.hf_repo}")
    print(f"-----------------------")

    # 1. Load Model and Tokenizer with Unsloth (4-bit quantized base).
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.model_name,
        max_seq_length=2048,
        load_in_4bit=True,
        fast_inference=False,
    )
    # Attach LoRA adapters on the attention + MLP projections.
    model = FastLanguageModel.get_peft_model(
        model,
        r=8,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        lora_alpha=16,
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=3407,
    )
    # NOTE(review): presumably some transformers/TRL code path expects this
    # attribute on the wrapped model; added defensively — confirm still needed.
    if not hasattr(model, "warnings_issued"):
        model.warnings_issued = {}

    # 2. Prepare Dataset from Environment: each /reset yields one observation
    # (diff, available files, step index) that is rendered into a prompt.
    print(f"Fetching {args.steps} samples from environment...")
    train_samples = []
    # Fetching in bulk might be faster, but let's stick to the current logic for compatibility
    for _ in range(min(args.steps, 1000)):  # capped at 1000 samples
        try:
            r = requests.post(f"{args.env_url}/reset", timeout=10)
            if r.status_code == 200:
                obs = r.json()["observation"]
                prompt = get_agent_prompt(obs["diff"], obs["available_files"], obs["step_idx"])
                train_samples.append({"prompt": prompt, "system": SYSTEM_PROMPT})
        except Exception as e:
            # Any failure aborts collection; training proceeds with what we have.
            print(f"Warning: Failed to fetch sample: {e}")
            break
    if not train_samples:
        print("Error: No training samples fetched. Check ENV_URL.")
        sys.exit(1)
    dataset = Dataset.from_list(train_samples)

    # 3. Configure GRPO (effective batch = 1 x 4 accumulation steps).
    training_args = GRPOConfig(
        output_dir=args.output_dir,
        num_generations=args.num_generations,
        max_completion_length=512,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        learning_rate=5e-6,
        logging_steps=1,
        save_steps=100,
        max_steps=args.steps,
        report_to="wandb" if os.getenv("WANDB_API_KEY") else "none",
        bf16=True,
        push_to_hub=True if args.hf_repo else False,
        hub_model_id=args.hf_repo,
        hub_strategy="end",
    )

    # 4. Initialize Trainer with the HTTP-backed reward function.
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[get_reward_from_env_base(args.env_url)],
        args=training_args,
        train_dataset=dataset,
    )

    # 5. Launch Training
    print("Starting GRPO Training...")
    trainer.train()

    # 6. Final push: adapter + tokenizer to the Hub, else save LoRA locally.
    if args.hf_repo:
        print(f"Pushing final adapter to {args.hf_repo}...")
        model.push_to_hub(args.hf_repo, token=hf_token)
        tokenizer.push_to_hub(args.hf_repo, token=hf_token)
    else:
        final_path = os.path.join(args.output_dir, "final")
        model.save_pretrained_merged(final_path, tokenizer, save_method="lora")
        print(f"Saved locally to {final_path}")
# Standard script entry guard: only train when executed directly.
if __name__ == "__main__":
    main()