File size: 3,573 Bytes
03a7eb9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import argparse
import logging
import math
import re
from typing import Any

import httpx
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOConfig, GRPOTrainer


# Base URL of the environment server; the reward function posts to its
# /reset and /step endpoints.
ENV_URL = "http://127.0.0.1:7860"
# Identifier passed to `from_pretrained` for both the model and the tokenizer.
MODEL_NAME = "distilgpt2"


def _extract_text(completion: Any) -> str:
    if isinstance(completion, str):
        return completion
    if isinstance(completion, list):
        chunks = []
        for item in completion:
            if isinstance(item, dict) and "content" in item:
                chunks.append(str(item["content"]))
            else:
                chunks.append(str(item))
        return "\n".join(chunks)
    if isinstance(completion, dict):
        return str(completion.get("content", ""))
    return str(completion)


def _clean_fix(text: str) -> str:
    text = text.strip()
    text = re.sub(r"^```(?:python)?\s*", "", text)
    text = re.sub(r"\s*```$", "", text)
    return text.strip() or "pass"


def codearena_reward_func(completions, prompts, **kwargs):
    """Score each completion by submitting it to the environment server.

    For every completion the text is extracted, stripped of Markdown
    fences, and sent to the server at ``ENV_URL``: the episode is reset to
    task ``"easy-1"`` and the fix is posted to ``/step``. The ``reward``
    field of the step response becomes the score. Each completion gets at
    most two attempts.

    Args:
        completions: Generated completions, in any shape accepted by
            ``_extract_text``.
        prompts: Unused here; kept as part of the existing signature.
        **kwargs: Extra keyword arguments; ignored.

    Returns:
        One float per completion, clamped to ``[0.001, 0.999]``. A
        completion whose attempts all fail, or whose reward is NaN, scores
        ``0.001``.
    """
    logger = logging.getLogger(__name__)
    floor, ceiling = 0.001, 0.999
    max_attempts = 2
    rewards = []
    with httpx.Client(timeout=60.0) as client:
        for completion in completions:
            proposed_fix = _clean_fix(_extract_text(completion))
            reward = floor
            for attempt in range(1, max_attempts + 1):
                try:
                    reset = client.post(f"{ENV_URL}/reset", json={"task_id": "easy-1"})
                    # Do not step (and score) an episode whose reset failed.
                    reset.raise_for_status()
                    res = client.post(
                        f"{ENV_URL}/step",
                        json={"proposed_fix": proposed_fix},
                    )
                    res.raise_for_status()
                    reward = float(res.json().get("reward", floor))
                    break
                except Exception:
                    # Scoring is best-effort: a failure must not abort
                    # training, but it should be visible in the logs.
                    logger.warning(
                        "Reward request failed (attempt %d/%d)",
                        attempt,
                        max_attempts,
                        exc_info=True,
                    )
                    reward = floor
            # NaN compares false against everything, so max/min would let it
            # through as the ceiling value; treat it as a failed score.
            if math.isnan(reward):
                reward = floor
            rewards.append(max(floor, min(ceiling, reward)))
    return rewards


def main():
    """Run a short GRPO training check using the environment reward.

    Parses ``--max-steps`` (default 3) and ``--output-dir`` (default
    ``./grpo-check-output``), builds a six-prompt dataset, loads
    ``MODEL_NAME`` together with its tokenizer, and trains with
    ``GRPOTrainer`` using ``codearena_reward_func`` as the only reward
    function. Prints a confirmation line when training returns.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--max-steps", type=int, default=3)
    parser.add_argument("--output-dir", type=str, default="./grpo-check-output")
    args = parser.parse_args()

    # NOTE(review): the first prompt contains literal backslash-n sequences
    # ("\\n"), not real newlines — confirm this is intended.
    prompts = [
        "Fix this Python function: def average_list(numbers)\\n    if length(numbers) == 0:\\n        return 0\\n    return sum(numbers) / length(numbers)",
        "Repair all root-cause issues in the function and keep readability high.",
        "Return a corrected Python function only. Ensure tests pass.",
        "Fix missing syntax and replace invalid APIs with valid Python APIs.",
        "Correct both compile and semantic issues in the provided function.",
        "Provide a secure, clean fix for average_list in Python.",
    ]
    train_dataset = Dataset.from_dict({"prompt": prompts})

    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    # The trainer expects left padding on the tokenizer it is given.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="left")
    if tokenizer.pad_token is None:
        # Fall back to the end-of-sequence token when no pad token exists.
        tokenizer.pad_token = tokenizer.eos_token

    training_args = GRPOConfig(
        output_dir=args.output_dir,
        learning_rate=1e-5,
        max_steps=args.max_steps,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=1,
        logging_steps=1,
        num_generations=2,
        max_prompt_length=256,
        max_completion_length=96,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.1,
        shuffle_dataset=False,
        seed=42,
        bf16=False,
        fp16=False,
        report_to=[],
    )

    trainer = GRPOTrainer(
        model=model,
        reward_funcs=codearena_reward_func,
        args=training_args,
        train_dataset=train_dataset,
        # Hand over the tokenizer prepared above; previously it was built
        # and then never used.
        processing_class=tokenizer,
    )
    trainer.train()
    print("GRPO check finished.")


# Run the training check only when this file is executed as a script.
if __name__ == "__main__":
    main()