import argparse
import os
import sys
from pathlib import Path

# Import unsloth before trl/transformers so its performance patches apply
# to the classes those libraries define.
from unsloth import FastLanguageModel, PatchFastRL

import requests
from datasets import Dataset
from huggingface_hub import login
from trl import GRPOConfig, GRPOTrainer

# Add the project root to sys.path so commitguard_env is importable.
root_dir = Path(__file__).resolve().parent.parent
if str(root_dir) not in sys.path:
    sys.path.append(str(root_dir))

from commitguard_env.grpo_prompt import SYSTEM_PROMPT, get_agent_prompt

# Patch TRL's GRPO implementation for Unsloth speedups.
PatchFastRL("GRPO", FastLanguageModel)
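# Unsloth's GRPO examples apply this patch before FastLanguageModel.from_pretrained
# is called, so the patched trainer classes are in place when the model loads below.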
def get_reward_from_env_base(env_url):
    """Build a TRL-compatible reward function that scores each completion
    by stepping the CommitGuard environment over HTTP."""

    def reward_fn(prompts, completions, **kwargs) -> list[float]:
        rewards = []
        for completion in completions:
            try:
                payload = {"action": completion}
                r = requests.post(f"{env_url}/step", json=payload, timeout=15)
                if r.status_code == 200:
                    rewards.append(float(r.json().get("reward", 0.0)))
                else:
                    rewards.append(-0.5)
            except Exception:
                rewards.append(-1.0)
        return rewards

    return reward_fn
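# The reward function above assumes this environment HTTP contract, inferred
# from the calls in this script rather than an official spec:
#   POST {env_url}/reset -> {"observation": {"diff": ..., "available_files": [...], "step_idx": ...}}
#   POST {env_url}/step  with {"action": "<completion>"} -> {"reward": <float>}
# Non-200 responses score -0.5 and transport errors -1.0, so a flaky
# environment penalizes the rollout instead of crashing training.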
def main():
    parser = argparse.ArgumentParser(description="CommitGuard GRPO Trainer for Hugging Face Hub Jobs")
    parser.add_argument("--model_name", type=str, default=os.getenv("MODEL_NAME", "meta-llama/Llama-3.2-3B-Instruct"))
    parser.add_argument("--output_dir", type=str, default="outputs/commitguard-hf")
    parser.add_argument("--steps", type=int, default=int(os.getenv("STEPS", "500")))
    parser.add_argument("--env_url", type=str, default=os.getenv("ENV_URL", "http://localhost:8000"))
    parser.add_argument("--hf_repo", type=str, default=os.getenv("HF_REPO"))
    parser.add_argument("--wandb", type=str, default=os.getenv("WANDB_PROJECT", "commitguard-rlvr"))
    parser.add_argument("--num_generations", type=int, default=int(os.getenv("NUM_GENERATIONS", "4")))
    args = parser.parse_args()
    # 0. Auth
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        login(token=hf_token)
    if args.wandb:
        os.environ["WANDB_PROJECT"] = args.wandb
    if os.getenv("WANDB_API_KEY"):
        import wandb
        wandb.login(key=os.getenv("WANDB_API_KEY"))
| print(f"--- Training Config ---") | |
| print(f"Model: {args.model_name}") | |
| print(f"Steps: {args.steps}") | |
| print(f"Env URL: {args.env_url}") | |
| print(f"HF Repo: {args.hf_repo}") | |
| print(f"-----------------------") | |
    # 1. Load Model and Tokenizer with Unsloth
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.model_name,
        max_seq_length=2048,
        load_in_4bit=True,
        fast_inference=False,
    )
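    # 4-bit loading keeps a ~3B model within a single small GPU;
    # fast_inference=False means Unsloth's vLLM-backed generation path is not
    # used, so GRPO sampling goes through the standard generate() loop.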
    model = FastLanguageModel.get_peft_model(
        model,
        r=8,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        lora_alpha=16,
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=3407,
    )
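    # With r=8 and lora_alpha=16, the LoRA scaling factor alpha/r is 2; the
    # target_modules list covers every attention and MLP projection, the
    # common Unsloth default set for Llama-style models.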
    # Compatibility shim: the HF Trainer reads model.warnings_issued, which the
    # Unsloth/PEFT-wrapped model may not expose.
    if not hasattr(model, "warnings_issued"):
        model.warnings_issued = {}
    # 2. Prepare Dataset from Environment
    num_samples = min(args.steps, 1000)
    print(f"Fetching up to {num_samples} samples from environment...")
    train_samples = []
    # One /reset call per sample; a bulk endpoint would be faster, but this
    # matches the environment's current API.
    for _ in range(num_samples):
        try:
            r = requests.post(f"{args.env_url}/reset", timeout=10)
            if r.status_code == 200:
                obs = r.json()["observation"]
                prompt = get_agent_prompt(obs["diff"], obs["available_files"], obs["step_idx"])
                train_samples.append({"prompt": prompt, "system": SYSTEM_PROMPT})
        except Exception as e:
            print(f"Warning: Failed to fetch sample: {e}")
            break
    if not train_samples:
        print("Error: No training samples fetched. Check ENV_URL.")
        sys.exit(1)
    dataset = Dataset.from_list(train_samples)
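    # GRPOTrainer consumes the "prompt" column; in recent TRL versions, extra
    # columns such as "system" are forwarded to reward functions as keyword
    # arguments (they arrive in reward_fn's **kwargs above).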
    # 3. Configure GRPO
    training_args = GRPOConfig(
        output_dir=args.output_dir,
        num_generations=args.num_generations,
        max_completion_length=512,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        learning_rate=5e-6,
        logging_steps=1,
        save_steps=100,
        max_steps=args.steps,
        report_to="wandb" if os.getenv("WANDB_API_KEY") else "none",
        bf16=True,
        push_to_hub=bool(args.hf_repo),
        hub_model_id=args.hf_repo,
        hub_strategy="end",
    )
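    # GRPO samples num_generations completions per prompt and normalizes their
    # rewards within each group to form advantages, so the environment-based
    # reward above is compared across the 4 completions of the same diff.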
    # 4. Initialize Trainer
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[get_reward_from_env_base(args.env_url)],
        args=training_args,
        train_dataset=dataset,
    )
    # 5. Launch Training
    print("Starting GRPO Training...")
    trainer.train()
    # 6. Final Push
    if args.hf_repo:
        print(f"Pushing final adapter to {args.hf_repo}...")
        model.push_to_hub(args.hf_repo, token=hf_token)
        tokenizer.push_to_hub(args.hf_repo, token=hf_token)
    else:
        # Save just the LoRA adapter locally (save_method="lora" skips merging).
        final_path = os.path.join(args.output_dir, "final")
        model.save_pretrained_merged(final_path, tokenizer, save_method="lora")
        print(f"Saved locally to {final_path}")


if __name__ == "__main__":
    main()
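# Example invocation (a sketch: the filename and sample values below are
# placeholders, not part of a documented interface):
#
#   export HF_TOKEN=hf_...          # enables hub login and the final push
#   export WANDB_API_KEY=...        # optional; switches report_to to "wandb"
#   python train_grpo_hf.py \
#       --model_name meta-llama/Llama-3.2-3B-Instruct \
#       --env_url http://localhost:8000 \
#       --steps 500 \
#       --hf_repo your-username/commitguard-grpo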