Spaces:
Running
Running
Upload folder using huggingface_hub
Browse files- train_demo.py +18 -11
train_demo.py
CHANGED
|
@@ -13,6 +13,7 @@ import hashlib
|
|
| 13 |
import re
|
| 14 |
import os
|
| 15 |
|
|
|
|
| 16 |
from datasets import Dataset
|
| 17 |
from trl import GRPOConfig, GRPOTrainer
|
| 18 |
from trl.experimental.openenv import generate_rollout_completions
|
|
@@ -23,7 +24,7 @@ from skill_invocation_env.client import SkillInvocationEnv
|
|
| 23 |
from skill_invocation_env.models import SkillInvocationAction
|
| 24 |
|
| 25 |
# ββ Configuration ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 26 |
-
MODEL_ID = os.getenv("MODEL_ID", "Qwen/Qwen2.5-
|
| 27 |
ENV_URL = os.getenv("ENV_URL", "https://mpnikhil-skill-invocation-env.hf.space")
|
| 28 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 29 |
OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./outputs/qwen-skill-env")
|
|
@@ -264,12 +265,18 @@ def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]:
|
|
| 264 |
seed = _extract_seed(prompt_text)
|
| 265 |
|
| 266 |
env = SkillInvocationEnv(base_url=ENV_URL, connect_timeout_s=60)
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 273 |
all_prompt_ids.append(episode["prompt_ids"])
|
| 274 |
all_completion_ids.append(episode["completion_ids"])
|
| 275 |
all_logprobs.append(episode["logprobs"])
|
|
@@ -363,12 +370,12 @@ if __name__ == "__main__":
|
|
| 363 |
output_dir=OUTPUT_DIR,
|
| 364 |
use_vllm=True,
|
| 365 |
vllm_mode="colocate",
|
| 366 |
-
vllm_gpu_memory_utilization=0.
|
| 367 |
num_train_epochs=1,
|
| 368 |
num_generations=NUM_GENERATIONS,
|
| 369 |
-
max_completion_length=
|
| 370 |
-
per_device_train_batch_size=
|
| 371 |
-
gradient_accumulation_steps=
|
| 372 |
learning_rate=1e-6,
|
| 373 |
logging_steps=1,
|
| 374 |
save_steps=50,
|
|
|
|
| 13 |
import re
|
| 14 |
import os
|
| 15 |
|
| 16 |
+
import wandb
|
| 17 |
from datasets import Dataset
|
| 18 |
from trl import GRPOConfig, GRPOTrainer
|
| 19 |
from trl.experimental.openenv import generate_rollout_completions
|
|
|
|
| 24 |
from skill_invocation_env.models import SkillInvocationAction
|
| 25 |
|
| 26 |
# ββ Configuration ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 27 |
+
MODEL_ID = os.getenv("MODEL_ID", "Qwen/Qwen2.5-3B-Instruct")
|
| 28 |
ENV_URL = os.getenv("ENV_URL", "https://mpnikhil-skill-invocation-env.hf.space")
|
| 29 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 30 |
OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./outputs/qwen-skill-env")
|
|
|
|
| 265 |
seed = _extract_seed(prompt_text)
|
| 266 |
|
| 267 |
env = SkillInvocationEnv(base_url=ENV_URL, connect_timeout_s=60)
|
| 268 |
+
try:
|
| 269 |
+
episode = rollout_once(
|
| 270 |
+
trainer=trainer,
|
| 271 |
+
env=env,
|
| 272 |
+
tokenizer=tokenizer,
|
| 273 |
+
env_seed=seed,
|
| 274 |
+
)
|
| 275 |
+
finally:
|
| 276 |
+
try:
|
| 277 |
+
env.close()
|
| 278 |
+
except Exception:
|
| 279 |
+
pass
|
| 280 |
all_prompt_ids.append(episode["prompt_ids"])
|
| 281 |
all_completion_ids.append(episode["completion_ids"])
|
| 282 |
all_logprobs.append(episode["logprobs"])
|
|
|
|
| 370 |
output_dir=OUTPUT_DIR,
|
| 371 |
use_vllm=True,
|
| 372 |
vllm_mode="colocate",
|
| 373 |
+
vllm_gpu_memory_utilization=0.3,
|
| 374 |
num_train_epochs=1,
|
| 375 |
num_generations=NUM_GENERATIONS,
|
| 376 |
+
max_completion_length=512,
|
| 377 |
+
per_device_train_batch_size=1,
|
| 378 |
+
gradient_accumulation_steps=32,
|
| 379 |
learning_rate=1e-6,
|
| 380 |
logging_steps=1,
|
| 381 |
save_steps=50,
|