havinashpatil commited on
Commit
9204c04
·
1 Parent(s): 8599a81

feat: use m-a-p/Code-Feedback dataset for GRPO training

Browse files
Files changed (1) hide show
  1. train_grpo.ipynb +21 -9
train_grpo.ipynb CHANGED
@@ -6,7 +6,8 @@
6
  "source": [
7
  "# GRPO Training with CodeArena RL Benchmark\n",
8
  "\n",
9
- "This notebook demonstrates how to connect our custom `codearena-rl-benchmark` environment to HuggingFace's `trl.GRPOTrainer`."
 
10
  ]
11
  },
12
  {
@@ -27,7 +28,7 @@
27
  "outputs": [],
28
  "source": [
29
  "import torch\n",
30
- "from datasets import Dataset\n",
31
  "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
32
  "from trl import GRPOConfig, GRPOTrainer\n",
33
  "import httpx\n",
@@ -86,13 +87,24 @@
86
  "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
87
  "tokenizer.pad_token = tokenizer.eos_token\n",
88
  "\n",
89
- "# Sample training dataset (prompts extracted from tasks)\n",
90
- "# In a real setup, you'd reset the env for each prompt to get the initial buggy_code.\n",
91
- "dataset = Dataset.from_dict({\n",
92
- " \"prompt\": [\n",
93
- " \"Fix this Python code:\\ndef average_list(numbers)\\n if length(numbers) == 0:\\n return 0\\n return sum(numbers) / length(numbers)\"\n",
94
- " ]\n",
95
- "})\n",
 
 
 
 
 
 
 
 
 
 
 
96
  "\n",
97
  "# Initialize GRPO Trainer\n",
98
  "training_args = GRPOConfig(\n",
 
6
  "source": [
7
  "# GRPO Training with CodeArena RL Benchmark\n",
8
  "\n",
9
+ "This notebook demonstrates how to connect our custom `codearena-rl-benchmark` environment to HuggingFace's `trl.GRPOTrainer`.\n",
10
+ "It uses the `m-a-p/Code-Feedback` dataset to train the LLM for coding debugging and improving time complexity."
11
  ]
12
  },
13
  {
 
28
  "outputs": [],
29
  "source": [
30
  "import torch\n",
31
+ "from datasets import load_dataset\n",
32
  "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
33
  "from trl import GRPOConfig, GRPOTrainer\n",
34
  "import httpx\n",
 
87
  "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
88
  "tokenizer.pad_token = tokenizer.eos_token\n",
89
  "\n",
90
+ "# Load dataset for Coding Debugging and Time Complexity Optimization\n",
91
+ "dataset = load_dataset(\"m-a-p/Code-Feedback\", split=\"train\")\n",
92
+ "\n",
93
+ "def format_prompt(example):\n",
94
+ " # m-a-p/Code-Feedback contains 'messages' with user and assistant roles\n",
95
+ " messages = example.get('messages', [])\n",
96
+ " user_query = \"\"\n",
97
+ " if messages and len(messages) > 0 and messages[0].get('role') == 'user':\n",
98
+ " user_query = messages[0].get('content', '')\n",
99
+ " \n",
100
+ " prompt = f\"Optimize and debug this code to improve time complexity:\\n{user_query}\"\n",
101
+ " return {\"prompt\": prompt}\n",
102
+ "\n",
103
+ "dataset = dataset.map(format_prompt)\n",
104
+ "# Keep only the prompt column for the trainer\n",
105
+ "dataset = dataset.select_columns([\"prompt\"])\n",
106
+ "# Limit for demo purposes\n",
107
+ "dataset = dataset.select(range(100))\n",
108
  "\n",
109
  "# Initialize GRPO Trainer\n",
110
  "training_args = GRPOConfig(\n",