Spaces:
Sleeping
Sleeping
havinashpatil committed on
Commit ·
9204c04
1
Parent(s): 8599a81
feat: use m-a-p/Code-Feedback dataset for GRPO training
Browse files
- train_grpo.ipynb +21 -9
train_grpo.ipynb
CHANGED
|
@@ -6,7 +6,8 @@
|
|
| 6 |
"source": [
|
| 7 |
"# GRPO Training with CodeArena RL Benchmark\n",
|
| 8 |
"\n",
|
| 9 |
-
"This notebook demonstrates how to connect our custom `codearena-rl-benchmark` environment to HuggingFace's `trl.GRPOTrainer`."
|
|
|
|
| 10 |
]
|
| 11 |
},
|
| 12 |
{
|
|
@@ -27,7 +28,7 @@
|
|
| 27 |
"outputs": [],
|
| 28 |
"source": [
|
| 29 |
"import torch\n",
|
| 30 |
-
"from datasets import
|
| 31 |
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
|
| 32 |
"from trl import GRPOConfig, GRPOTrainer\n",
|
| 33 |
"import httpx\n",
|
|
@@ -86,13 +87,24 @@
|
|
| 86 |
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
|
| 87 |
"tokenizer.pad_token = tokenizer.eos_token\n",
|
| 88 |
"\n",
|
| 89 |
-
"#
|
| 90 |
-
"
|
| 91 |
-
"
|
| 92 |
-
"
|
| 93 |
-
"
|
| 94 |
-
" ]\n",
|
| 95 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
"\n",
|
| 97 |
"# Initialize GRPO Trainer\n",
|
| 98 |
"training_args = GRPOConfig(\n",
|
|
|
|
| 6 |
"source": [
|
| 7 |
"# GRPO Training with CodeArena RL Benchmark\n",
|
| 8 |
"\n",
|
| 9 |
+
"This notebook demonstrates how to connect our custom `codearena-rl-benchmark` environment to HuggingFace's `trl.GRPOTrainer`.\n",
|
| 10 |
+
"It uses the `m-a-p/Code-Feedback` dataset to train the LLM for coding debugging and improving time complexity."
|
| 11 |
]
|
| 12 |
},
|
| 13 |
{
|
|
|
|
| 28 |
"outputs": [],
|
| 29 |
"source": [
|
| 30 |
"import torch\n",
|
| 31 |
+
"from datasets import load_dataset\n",
|
| 32 |
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
|
| 33 |
"from trl import GRPOConfig, GRPOTrainer\n",
|
| 34 |
"import httpx\n",
|
|
|
|
| 87 |
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
|
| 88 |
"tokenizer.pad_token = tokenizer.eos_token\n",
|
| 89 |
"\n",
|
| 90 |
+
"# Load dataset for Coding Debugging and Time Complexity Optimization\n",
|
| 91 |
+
"dataset = load_dataset(\"m-a-p/Code-Feedback\", split=\"train\")\n",
|
| 92 |
+
"\n",
|
| 93 |
+
"def format_prompt(example):\n",
|
| 94 |
+
" # m-a-p/Code-Feedback contains 'messages' with user and assistant roles\n",
|
| 95 |
+
" messages = example.get('messages', [])\n",
|
| 96 |
+
" user_query = \"\"\n",
|
| 97 |
+
" if messages and len(messages) > 0 and messages[0].get('role') == 'user':\n",
|
| 98 |
+
" user_query = messages[0].get('content', '')\n",
|
| 99 |
+
" \n",
|
| 100 |
+
" prompt = f\"Optimize and debug this code to improve time complexity:\\n{user_query}\"\n",
|
| 101 |
+
" return {\"prompt\": prompt}\n",
|
| 102 |
+
"\n",
|
| 103 |
+
"dataset = dataset.map(format_prompt)\n",
|
| 104 |
+
"# Keep only the prompt column for the trainer\n",
|
| 105 |
+
"dataset = dataset.select_columns([\"prompt\"])\n",
|
| 106 |
+
"# Limit for demo purposes\n",
|
| 107 |
+
"dataset = dataset.select(range(100))\n",
|
| 108 |
"\n",
|
| 109 |
"# Initialize GRPO Trainer\n",
|
| 110 |
"training_args = GRPOConfig(\n",
|