sissississi Claude Opus 4.6 committed on
Commit
9b2abc6
·
1 Parent(s): 9a9721a

Fix MAX_SEQ_LENGTH: 1024 was too small for prompt+completion, bump to 2048

Browse files
Files changed (1) hide show
  1. training/train_grpo.ipynb +1 -21
training/train_grpo.ipynb CHANGED
@@ -128,27 +128,7 @@
128
  "execution_count": null,
129
  "metadata": {},
130
  "outputs": [],
131
- "source": [
132
- "import torch\n",
133
- "\n",
134
- "MAX_STEPS = 300\n",
135
- "MAX_SEQ_LENGTH = 1024\n",
136
- "LORA_RANK = 32\n",
137
- "\n",
138
- "gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else \"cpu\"\n",
139
- "print(f\"GPU: {gpu_name}\")\n",
140
- "\n",
141
- "if any(x in gpu_name.upper() for x in [\"H100\",\"A100\",\"L40\",\"4090\"]):\n",
142
- " MODEL_NAME = \"unsloth/Qwen3-4B\"\n",
143
- " LOAD_IN_4BIT = False; FAST_INFERENCE = True; NUM_GENERATIONS = 4; LR = 5e-5\n",
144
- " print(\"Config: Qwen3-4B instruct, BF16 + vLLM\")\n",
145
- "else:\n",
146
- " MODEL_NAME = \"unsloth/Qwen3-4B-unsloth-bnb-4bit\"\n",
147
- " LOAD_IN_4BIT = True; FAST_INFERENCE = False; NUM_GENERATIONS = 2; LR = 2e-4\n",
148
- " print(\"Config: Qwen3-4B instruct, 4-bit (T4)\")\n",
149
- "\n",
150
- "print(f\"Model: {MODEL_NAME} | Steps: {MAX_STEPS} | Gens: {NUM_GENERATIONS}\")"
151
- ]
152
  },
153
  {
154
  "cell_type": "markdown",
 
128
  "execution_count": null,
129
  "metadata": {},
130
  "outputs": [],
131
+ "source": "import torch\n\nMAX_STEPS = 300\nMAX_SEQ_LENGTH = 2048 # Must be >= max_prompt_length + max_completion_length (512+1024)\nLORA_RANK = 32\n\ngpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else \"cpu\"\nprint(f\"GPU: {gpu_name}\")\n\nif any(x in gpu_name.upper() for x in [\"H100\",\"A100\",\"L40\",\"4090\"]):\n MODEL_NAME = \"unsloth/Qwen3-4B\"\n LOAD_IN_4BIT = False; FAST_INFERENCE = True; NUM_GENERATIONS = 4; LR = 5e-5\n print(\"Config: Qwen3-4B instruct, BF16 + vLLM\")\nelse:\n MODEL_NAME = \"unsloth/Qwen3-4B-unsloth-bnb-4bit\"\n LOAD_IN_4BIT = True; FAST_INFERENCE = False; NUM_GENERATIONS = 2; LR = 2e-4\n print(\"Config: Qwen3-4B instruct, 4-bit (T4)\")\n\nprint(f\"Model: {MODEL_NAME} | Steps: {MAX_STEPS} | Gens: {NUM_GENERATIONS} | SeqLen: {MAX_SEQ_LENGTH}\")"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  },
133
  {
134
  "cell_type": "markdown",