import argparse
from unsloth import FastLanguageModel, PatchFastRL
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer
import requests

# Patch TRL for Unsloth speedups
PatchFastRL("GRPO", FastLanguageModel)

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--steps", type=int, default=500)
    args = parser.parse_args()

    # Optimized for speed on a single L4 GPU: 4-bit base model, short context
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name="unsloth/llama-3.2-3b-instruct-unsloth-bnb-4bit",
        max_seq_length=1024,
        load_in_4bit=True,
    )

    model = FastLanguageModel.get_peft_model(
        model,
        r=16,  # higher LoRA rank gives the adapter more capacity for the harder task
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_alpha=32,
        use_gradient_checkpointing="unsloth",
    )
    # Workaround: some trainer code paths expect a transformers-style `warnings_issued` dict on the model
    if not hasattr(model, "warnings_issued"):
        model.warnings_issued = {}

    print("Fetching 500 hard samples...")
    train_samples = []
    for _ in range(500):
        r = requests.post("http://localhost:8000/reset")
        if r.status_code == 200:
            obs = r.json()["observation"]
            from commitguard_env.grpo_prompt import get_agent_prompt, SYSTEM_PROMPT
            prompt = get_agent_prompt(obs["diff"], obs["available_files"], 0)
            train_samples.append({"prompt": prompt, "system": SYSTEM_PROMPT})
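
    # Optional sanity check: fail fast if the environment on localhost:8000
    # returned no usable samples, rather than training on an empty dataset.
    if not train_samples:
        raise RuntimeError("No samples fetched from http://localhost:8000/reset")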
    
    dataset = Dataset.from_list(train_samples)

    # HEAVY TRAINING CONFIG
    training_args = GRPOConfig(
        output_dir="outputs/commitguard-heavy",
        num_generations=4,
        max_completion_length=256,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,  # 2 x 8 = effective batch size of 16 per optimizer step
        learning_rate=2e-5,  # higher LR than usual for GRPO, to shift the weights quickly
        max_steps=args.steps,
        bf16=True,
        logging_steps=1,
        save_steps=100,
        report_to="none",
    )
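
    # Illustrative reward sketch (an assumption, not the project's actual scorer):
    # reward completions that parse as JSON, on the guess that the commitguard
    # environment expects JSON-formatted actions. Pass it via reward_funcs in
    # place of the constant placeholder below once real scoring is available.
    import json

    def json_format_reward(prompts, completions, **kwargs):
        # 1.0 for a completion that parses as JSON, 0.0 otherwise
        rewards = []
        for completion in completions:
            try:
                json.loads(completion)
                rewards.append(1.0)
            except json.JSONDecodeError:
                rewards.append(0.0)
        return rewards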

    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        # Constant placeholder reward: GRPO advantages come out zero, so this only exercises the pipeline
        reward_funcs=[lambda prompts, completions, **kwargs: [0.5] * len(completions)],
        args=training_args,
        train_dataset=dataset,
    )

    print(f"Starting HEAVY training for {args.steps} steps...")
    trainer.train()
    # save_method="lora" stores only the adapter weights, not a merged checkpoint
    model.save_pretrained_merged("outputs/commitguard-heavy/final", tokenizer, save_method="lora")

if __name__ == "__main__":
    main()
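
# Example invocation (script name illustrative; assumes the commitguard
# environment is already serving on http://localhost:8000):
#   python train_heavy.py --steps 500
#
# Reloading the saved adapters for inference (sketch, same Unsloth version):
#   model, tokenizer = FastLanguageModel.from_pretrained(
#       "outputs/commitguard-heavy/final", max_seq_length=1024, load_in_4bit=True)
#   FastLanguageModel.for_inference(model)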