mpnikhil committed on
Commit
ccefb27
·
verified ·
1 Parent(s): e3f21c2

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. train_demo.py +18 -11
train_demo.py CHANGED
@@ -13,6 +13,7 @@ import hashlib
13
  import re
14
  import os
15
 
 
16
  from datasets import Dataset
17
  from trl import GRPOConfig, GRPOTrainer
18
  from trl.experimental.openenv import generate_rollout_completions
@@ -23,7 +24,7 @@ from skill_invocation_env.client import SkillInvocationEnv
23
  from skill_invocation_env.models import SkillInvocationAction
24
 
25
  # ── Configuration ──────────────────────────────────────────────────────────────
26
- MODEL_ID = os.getenv("MODEL_ID", "Qwen/Qwen2.5-7B-Instruct")
27
  ENV_URL = os.getenv("ENV_URL", "https://mpnikhil-skill-invocation-env.hf.space")
28
  HF_TOKEN = os.getenv("HF_TOKEN")
29
  OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./outputs/qwen-skill-env")
@@ -264,12 +265,18 @@ def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]:
264
  seed = _extract_seed(prompt_text)
265
 
266
  env = SkillInvocationEnv(base_url=ENV_URL, connect_timeout_s=60)
267
- episode = rollout_once(
268
- trainer=trainer,
269
- env=env,
270
- tokenizer=tokenizer,
271
- env_seed=seed,
272
- )
 
 
 
 
 
 
273
  all_prompt_ids.append(episode["prompt_ids"])
274
  all_completion_ids.append(episode["completion_ids"])
275
  all_logprobs.append(episode["logprobs"])
@@ -363,12 +370,12 @@ if __name__ == "__main__":
363
  output_dir=OUTPUT_DIR,
364
  use_vllm=True,
365
  vllm_mode="colocate",
366
- vllm_gpu_memory_utilization=0.6,
367
  num_train_epochs=1,
368
  num_generations=NUM_GENERATIONS,
369
- max_completion_length=MAX_COMPLETION_LENGTH,
370
- per_device_train_batch_size=8,
371
- gradient_accumulation_steps=4,
372
  learning_rate=1e-6,
373
  logging_steps=1,
374
  save_steps=50,
 
13
  import re
14
  import os
15
 
16
+ import wandb
17
  from datasets import Dataset
18
  from trl import GRPOConfig, GRPOTrainer
19
  from trl.experimental.openenv import generate_rollout_completions
 
24
  from skill_invocation_env.models import SkillInvocationAction
25
 
26
  # ── Configuration ──────────────────────────────────────────────────────────────
27
+ MODEL_ID = os.getenv("MODEL_ID", "Qwen/Qwen2.5-3B-Instruct")
28
  ENV_URL = os.getenv("ENV_URL", "https://mpnikhil-skill-invocation-env.hf.space")
29
  HF_TOKEN = os.getenv("HF_TOKEN")
30
  OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./outputs/qwen-skill-env")
 
265
  seed = _extract_seed(prompt_text)
266
 
267
  env = SkillInvocationEnv(base_url=ENV_URL, connect_timeout_s=60)
268
+ try:
269
+ episode = rollout_once(
270
+ trainer=trainer,
271
+ env=env,
272
+ tokenizer=tokenizer,
273
+ env_seed=seed,
274
+ )
275
+ finally:
276
+ try:
277
+ env.close()
278
+ except Exception:
279
+ pass
280
  all_prompt_ids.append(episode["prompt_ids"])
281
  all_completion_ids.append(episode["completion_ids"])
282
  all_logprobs.append(episode["logprobs"])
 
370
  output_dir=OUTPUT_DIR,
371
  use_vllm=True,
372
  vllm_mode="colocate",
373
+ vllm_gpu_memory_utilization=0.3,
374
  num_train_epochs=1,
375
  num_generations=NUM_GENERATIONS,
376
+ max_completion_length=512,
377
+ per_device_train_batch_size=1,
378
+ gradient_accumulation_steps=32,
379
  learning_rate=1e-6,
380
  logging_steps=1,
381
  save_steps=50,