parth-1 commited on
Commit
c16c504
·
verified ·
1 Parent(s): 4ae43fc

Update grpo_train.py

Browse files
Files changed (1) hide show
  1. grpo_train.py +20 -13
grpo_train.py CHANGED
@@ -300,24 +300,31 @@ def reward_environment(prompts, completions, task_id=None, setup_actions=None, *
300
  # MODEL
301
  # =========================
302
 
303
- USE_4BIT = not torch.cuda.is_available() or torch.cuda.get_device_properties(0).total_memory < 40 * 1024**3
 
 
 
 
 
 
 
 
304
 
305
  model, tokenizer = FastLanguageModel.from_pretrained(
306
  model_name="unsloth/Llama-3.1-8B-Instruct",
307
  load_in_4bit=USE_4BIT,
308
- dtype = torch.bfloat16,
309
  max_seq_length=2048,
310
- dtype=None, # auto-detect bf16 on A100
311
  )
312
 
313
  model = FastLanguageModel.get_peft_model(
314
  model,
315
- r=32,
316
  target_modules=[
317
  "q_proj", "k_proj", "v_proj", "o_proj",
318
  "gate_proj", "up_proj", "down_proj",
319
  ],
320
- lora_alpha=64,
321
  lora_dropout=0,
322
  bias="none",
323
  use_gradient_checkpointing="unsloth",
@@ -336,16 +343,16 @@ trainer = GRPOTrainer(
336
  args=GRPOConfig(
337
  output_dir="outputs",
338
  learning_rate=2e-5,
339
- num_train_epochs=3,
340
- per_device_train_batch_size=2,
341
- gradient_accumulation_steps=4,
342
- num_generations=4,
343
  max_prompt_length=768,
344
  max_completion_length=128,
345
- logging_steps=5,
346
- warmup_ratio=0.1,
347
- bf16=True,
348
- fp16=False,
349
  report_to="none",
350
  ),
351
  train_dataset=dataset,
 
300
  # MODEL
301
  # =========================
302
 
303
+ if torch.cuda.is_available():
304
+ _vram = torch.cuda.get_device_properties(0).total_memory
305
+ _name = torch.cuda.get_device_name(0)
306
+ print(f"GPU: {_name} VRAM: {_vram / 1024**3:.1f} GB")
307
+ else:
308
+ _vram = 0
309
+ _name = "CPU"
310
+
311
+ USE_4BIT = _vram < 40 * 1024**3 # True for T4 (15 GB) and L4 (24 GB); False for A100 (80 GB)
312
 
313
  model, tokenizer = FastLanguageModel.from_pretrained(
314
  model_name="unsloth/Llama-3.1-8B-Instruct",
315
  load_in_4bit=USE_4BIT,
 
316
  max_seq_length=2048,
317
+ dtype=None,
318
  )
319
 
320
  model = FastLanguageModel.get_peft_model(
321
  model,
322
+ r=16 if USE_4BIT else 32,
323
  target_modules=[
324
  "q_proj", "k_proj", "v_proj", "o_proj",
325
  "gate_proj", "up_proj", "down_proj",
326
  ],
327
+ lora_alpha=32 if USE_4BIT else 64,
328
  lora_dropout=0,
329
  bias="none",
330
  use_gradient_checkpointing="unsloth",
 
343
  args=GRPOConfig(
344
  output_dir="outputs",
345
  learning_rate=2e-5,
346
+ num_train_epochs=1 if USE_4BIT else 3,
347
+ per_device_train_batch_size=1 if USE_4BIT else 2,
348
+ gradient_accumulation_steps=2 if USE_4BIT else 4,
349
+ num_generations=2 if USE_4BIT else 4,
350
  max_prompt_length=768,
351
  max_completion_length=128,
352
+ logging_steps=3 if USE_4BIT else 5,
353
+ warmup_steps=5 if USE_4BIT else 10,
354
+ bf16=not USE_4BIT,
355
+ fp16=USE_4BIT,
356
  report_to="none",
357
  ),
358
  train_dataset=dataset,