Spaces:
Running
Running
Commit ·
9b2abc6
1
Parent(s): 9a9721a
Fix MAX_SEQ_LENGTH: 1024 was too small for prompt+completion, bump to 2048
Browse files- training/train_grpo.ipynb +1 -21
training/train_grpo.ipynb
CHANGED
|
@@ -128,27 +128,7 @@
|
|
| 128 |
"execution_count": null,
|
| 129 |
"metadata": {},
|
| 130 |
"outputs": [],
|
| 131 |
-
"source": [
|
| 132 |
-
"import torch\n",
|
| 133 |
-
"\n",
|
| 134 |
-
"MAX_STEPS = 300\n",
|
| 135 |
-
"MAX_SEQ_LENGTH = 1024\n",
|
| 136 |
-
"LORA_RANK = 32\n",
|
| 137 |
-
"\n",
|
| 138 |
-
"gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else \"cpu\"\n",
|
| 139 |
-
"print(f\"GPU: {gpu_name}\")\n",
|
| 140 |
-
"\n",
|
| 141 |
-
"if any(x in gpu_name.upper() for x in [\"H100\",\"A100\",\"L40\",\"4090\"]):\n",
|
| 142 |
-
" MODEL_NAME = \"unsloth/Qwen3-4B\"\n",
|
| 143 |
-
" LOAD_IN_4BIT = False; FAST_INFERENCE = True; NUM_GENERATIONS = 4; LR = 5e-5\n",
|
| 144 |
-
" print(\"Config: Qwen3-4B instruct, BF16 + vLLM\")\n",
|
| 145 |
-
"else:\n",
|
| 146 |
-
" MODEL_NAME = \"unsloth/Qwen3-4B-unsloth-bnb-4bit\"\n",
|
| 147 |
-
" LOAD_IN_4BIT = True; FAST_INFERENCE = False; NUM_GENERATIONS = 2; LR = 2e-4\n",
|
| 148 |
-
" print(\"Config: Qwen3-4B instruct, 4-bit (T4)\")\n",
|
| 149 |
-
"\n",
|
| 150 |
-
"print(f\"Model: {MODEL_NAME} | Steps: {MAX_STEPS} | Gens: {NUM_GENERATIONS}\")"
|
| 151 |
-
]
|
| 152 |
},
|
| 153 |
{
|
| 154 |
"cell_type": "markdown",
|
|
|
|
| 128 |
"execution_count": null,
|
| 129 |
"metadata": {},
|
| 130 |
"outputs": [],
|
| 131 |
+
"source": "import torch\n\nMAX_STEPS = 300\nMAX_SEQ_LENGTH = 2048 # Must be >= max_prompt_length + max_completion_length (512+1024)\nLORA_RANK = 32\n\ngpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else \"cpu\"\nprint(f\"GPU: {gpu_name}\")\n\nif any(x in gpu_name.upper() for x in [\"H100\",\"A100\",\"L40\",\"4090\"]):\n MODEL_NAME = \"unsloth/Qwen3-4B\"\n LOAD_IN_4BIT = False; FAST_INFERENCE = True; NUM_GENERATIONS = 4; LR = 5e-5\n print(\"Config: Qwen3-4B instruct, BF16 + vLLM\")\nelse:\n MODEL_NAME = \"unsloth/Qwen3-4B-unsloth-bnb-4bit\"\n LOAD_IN_4BIT = True; FAST_INFERENCE = False; NUM_GENERATIONS = 2; LR = 2e-4\n print(\"Config: Qwen3-4B instruct, 4-bit (T4)\")\n\nprint(f\"Model: {MODEL_NAME} | Steps: {MAX_STEPS} | Gens: {NUM_GENERATIONS} | SeqLen: {MAX_SEQ_LENGTH}\")"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
},
|
| 133 |
{
|
| 134 |
"cell_type": "markdown",
|