Spaces:
Runtime error
Runtime error
Update grpo_train.py
Browse files- grpo_train.py +20 -13
grpo_train.py
CHANGED
|
@@ -300,24 +300,31 @@ def reward_environment(prompts, completions, task_id=None, setup_actions=None, *
|
|
| 300 |
# MODEL
|
| 301 |
# =========================
|
| 302 |
|
| 303 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 304 |
|
| 305 |
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 306 |
model_name="unsloth/Llama-3.1-8B-Instruct",
|
| 307 |
load_in_4bit=USE_4BIT,
|
| 308 |
-
dtype = torch.bfloat16,
|
| 309 |
max_seq_length=2048,
|
| 310 |
-
dtype=None,
|
| 311 |
)
|
| 312 |
|
| 313 |
model = FastLanguageModel.get_peft_model(
|
| 314 |
model,
|
| 315 |
-
r=32,
|
| 316 |
target_modules=[
|
| 317 |
"q_proj", "k_proj", "v_proj", "o_proj",
|
| 318 |
"gate_proj", "up_proj", "down_proj",
|
| 319 |
],
|
| 320 |
-
lora_alpha=64,
|
| 321 |
lora_dropout=0,
|
| 322 |
bias="none",
|
| 323 |
use_gradient_checkpointing="unsloth",
|
|
@@ -336,16 +343,16 @@ trainer = GRPOTrainer(
|
|
| 336 |
args=GRPOConfig(
|
| 337 |
output_dir="outputs",
|
| 338 |
learning_rate=2e-5,
|
| 339 |
-
num_train_epochs=3,
|
| 340 |
-
per_device_train_batch_size=2,
|
| 341 |
-
gradient_accumulation_steps=4,
|
| 342 |
-
num_generations=4,
|
| 343 |
max_prompt_length=768,
|
| 344 |
max_completion_length=128,
|
| 345 |
-
logging_steps=5,
|
| 346 |
-
|
| 347 |
-
bf16=
|
| 348 |
-
fp16=
|
| 349 |
report_to="none",
|
| 350 |
),
|
| 351 |
train_dataset=dataset,
|
|
|
|
| 300 |
# MODEL
|
| 301 |
# =========================
|
| 302 |
|
| 303 |
+
if torch.cuda.is_available():
|
| 304 |
+
_vram = torch.cuda.get_device_properties(0).total_memory
|
| 305 |
+
_name = torch.cuda.get_device_name(0)
|
| 306 |
+
print(f"GPU: {_name} VRAM: {_vram / 1024**3:.1f} GB")
|
| 307 |
+
else:
|
| 308 |
+
_vram = 0
|
| 309 |
+
_name = "CPU"
|
| 310 |
+
|
| 311 |
+
USE_4BIT = _vram < 40 * 1024**3 # True for T4 (15 GB) and L4 (24 GB); False for A100 (80 GB)
|
| 312 |
|
| 313 |
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 314 |
model_name="unsloth/Llama-3.1-8B-Instruct",
|
| 315 |
load_in_4bit=USE_4BIT,
|
|
|
|
| 316 |
max_seq_length=2048,
|
| 317 |
+
dtype=None,
|
| 318 |
)
|
| 319 |
|
| 320 |
model = FastLanguageModel.get_peft_model(
|
| 321 |
model,
|
| 322 |
+
r=16 if USE_4BIT else 32,
|
| 323 |
target_modules=[
|
| 324 |
"q_proj", "k_proj", "v_proj", "o_proj",
|
| 325 |
"gate_proj", "up_proj", "down_proj",
|
| 326 |
],
|
| 327 |
+
lora_alpha=32 if USE_4BIT else 64,
|
| 328 |
lora_dropout=0,
|
| 329 |
bias="none",
|
| 330 |
use_gradient_checkpointing="unsloth",
|
|
|
|
| 343 |
args=GRPOConfig(
|
| 344 |
output_dir="outputs",
|
| 345 |
learning_rate=2e-5,
|
| 346 |
+
num_train_epochs=1 if USE_4BIT else 3,
|
| 347 |
+
per_device_train_batch_size=1 if USE_4BIT else 2,
|
| 348 |
+
gradient_accumulation_steps=2 if USE_4BIT else 4,
|
| 349 |
+
num_generations=2 if USE_4BIT else 4,
|
| 350 |
max_prompt_length=768,
|
| 351 |
max_completion_length=128,
|
| 352 |
+
logging_steps=3 if USE_4BIT else 5,
|
| 353 |
+
warmup_steps=5 if USE_4BIT else 10,
|
| 354 |
+
bf16=not USE_4BIT,
|
| 355 |
+
fp16=USE_4BIT,
|
| 356 |
report_to="none",
|
| 357 |
),
|
| 358 |
train_dataset=dataset,
|