# Source: trl-mcsd/examples/scripts/nemo_gym/train_multi_environment.py
# Commit 1fa3c6c (verified) by ihbkaiser: "Implement MCSD for experimental SDPO"
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# /// script
# dependencies = [
# "trl[vllm]",
# "nemo_gym @ git+https://github.com/NVIDIA-NeMo/Gym",
# ]
# ///
import argparse
import asyncio
import json
import os
from dataclasses import dataclass
from typing import Any
import aiohttp
import requests
import yaml
from datasets import Dataset, load_dataset
from omegaconf import OmegaConf
from transformers import AutoTokenizer
from trl import GRPOConfig, GRPOTrainer
@dataclass
class NeMoGymGRPOConfig(GRPOConfig):
    """GRPO configuration extended with the settings used to reach NeMo Gym agent servers.

    The two extra fields are read from ``trainer.args`` by the rollout function in this
    script and forwarded to the code that issues the agent HTTP requests.
    """

    # Maps an agent name (looked up from each dataset item's ``agent_ref["name"]``)
    # to the base URL of the server hosting that agent.
    agent_servers: dict[str, str] | None = None
    # Total timeout passed to ``aiohttp.ClientTimeout(total=...)`` for every agent
    # ``/run`` request; the default of 10800 corresponds to three hours in seconds.
    request_timeout: float = 10800
def get_agent_servers(
    head_server_host: str = "127.0.0.1",
    head_server_port: int = 11000,
) -> dict[str, str]:
    """Discover agent server URLs from the head server's global config.

    Fetches ``/global_config_dict_yaml`` from the head server, parses the YAML, and
    collects a base URL for every top-level entry that declares ``responses_api_agents``
    with both a ``host`` and a ``port``.

    Args:
        head_server_host: Hostname or IP of the head server.
        head_server_port: Port of the head server.

    Returns:
        Mapping from top-level config entry name to ``http://<host>:<port>``.

    Raises:
        RuntimeError: If the HTTP request to the head server fails.
        ValueError: If the config contains no usable agent entry.
    """
    # Agents advertising one of these hosts are rewritten to the head server's host.
    loopback_hosts = ("127.0.0.1", "0.0.0.0", "localhost")
    config_url = f"http://{head_server_host}:{head_server_port}/global_config_dict_yaml"

    try:
        reply = requests.get(config_url, timeout=10)
        reply.raise_for_status()
        global_config = OmegaConf.create(yaml.safe_load(reply.text))

        discovered = {}
        for server_name, server_config in global_config.items():
            if not hasattr(server_config, "responses_api_agents"):
                continue
            agents = server_config.responses_api_agents
            for agent_key in agents.keys():
                agent_config = getattr(agents, agent_key)
                if not (hasattr(agent_config, "host") and hasattr(agent_config, "port")):
                    continue
                host = agent_config.host
                if host in loopback_hosts:
                    host = head_server_host
                # Keyed by the top-level entry name, so a later agent under the same
                # entry replaces an earlier one.
                discovered[server_name] = f"http://{host}:{agent_config.port}"

        if not discovered:
            raise ValueError("No agents found in global config")
        return discovered
    except requests.exceptions.RequestException as err:
        raise RuntimeError(
            f"Failed to connect to head server at {head_server_host}:{head_server_port}: {err}"
        ) from err
def reward_fn(completions: list[str], **kwargs) -> list[float]:
    """Return the environment-provided reward for each completion.

    Rewards are not computed here: they are read from the ``env_reward`` keyword
    argument and converted to ``float``. ``completions`` is accepted as part of the
    signature but is not inspected.

    Args:
        completions: Generated completions (unused).
        **kwargs: Must contain ``env_reward``, an iterable of numeric rewards.

    Returns:
        The rewards from ``env_reward`` as a list of floats, in the same order.

    Raises:
        ValueError: If ``env_reward`` is missing from ``kwargs`` or is ``None``.
    """
    env_rewards = kwargs.get("env_reward")
    if env_rewards is None:
        # Explicit raise instead of ``assert`` so the check is not stripped by ``python -O``.
        raise ValueError("env_reward not found in kwargs")
    return [float(r) for r in env_rewards]
async def call_nemo_gym_agents(
    prompts: list[str],
    dataset_items: list[dict[str, Any]],
    agent_servers: dict[str, str],
    timeout: float,
    max_completion_length: int = 4096,
    temperature: float = 1.0,
    top_p: float = 0.999,
) -> list[dict[str, Any]]:
    """POST one ``/run`` request per (prompt, dataset item) pair to its agent server.

    Each request body is a copy of the dataset item. If the item has no
    ``responses_create_params``, a single user message containing ``prompt`` is used.
    ``max_output_tokens`` is filled in only when absent, while ``temperature`` and
    ``top_p`` are always overwritten. All requests are issued concurrently.

    Args:
        prompts: Prompt strings, paired positionally with ``dataset_items``.
        dataset_items: Request payloads; each must carry ``agent_ref["name"]``.
        agent_servers: Mapping from agent name to base URL.
        timeout: Total per-request timeout handed to ``aiohttp.ClientTimeout``.
        max_completion_length: Default for ``max_output_tokens``.
        temperature: Sampling temperature sent with every request.
        top_p: Nucleus sampling value sent with every request.

    Returns:
        One dict per request, in input order. A failed request yields
        ``{"response": {"output": []}, "reward": 0.0, "error": <message>}``.

    Raises:
        ValueError: If an item's ``agent_ref`` is missing or names an unknown agent.
    """
    # Validate and build every payload before touching the network, so a bad item
    # fails fast instead of leaving already-created request coroutines un-awaited.
    pending: list[tuple[str, dict[str, Any]]] = []
    for prompt, item in zip(prompts, dataset_items, strict=False):
        agent_ref = item.get("agent_ref", {})
        agent_name = agent_ref.get("name") if isinstance(agent_ref, dict) else None
        if not agent_name or agent_name not in agent_servers:
            raise ValueError(
                f"Missing or invalid agent_ref. Got: {agent_ref}. Available: {list(agent_servers.keys())}"
            )

        request_body = item.copy()
        if "responses_create_params" in request_body:
            # ``item.copy()`` is shallow: copy the nested params too so the sampling
            # overrides below never leak into the caller's (possibly shared) dict.
            params = dict(request_body["responses_create_params"])
        else:
            params = {"input": [{"role": "user", "content": prompt}]}
        params.setdefault("max_output_tokens", max_completion_length)
        params["temperature"] = temperature
        params["top_p"] = top_p
        request_body["responses_create_params"] = params

        pending.append((f"{agent_servers[agent_name]}/run", request_body))

    async with aiohttp.ClientSession(cookie_jar=aiohttp.CookieJar()) as session:
        request_timeout = aiohttp.ClientTimeout(total=timeout)
        post_calls = [session.post(url, json=body, timeout=request_timeout) for url, body in pending]
        # ``return_exceptions=True`` keeps one failing request from cancelling the rest.
        responses = await asyncio.gather(*post_calls, return_exceptions=True)

        results: list[dict[str, Any]] = []
        for i, response in enumerate(responses):
            try:
                if isinstance(response, Exception):
                    raise response
                json_data = await response.json()
                if not isinstance(json_data, dict):
                    raise ValueError(f"Expected dict, got {type(json_data)}")
                results.append(json_data)
            except Exception as e:
                # Some exceptions (e.g. timeouts) stringify to "", which would make the
                # "error" field falsy; fall back to the exception class name.
                error_text = str(e) or type(e).__name__
                print(f"WARNING: Request {i} failed: {error_text}")
                results.append({"response": {"output": []}, "reward": 0.0, "error": error_text})

        return results
def _has_usable_output(output_items: list[dict[str, Any]]) -> bool:
    """Return True if any output item is a function call or a message with non-blank text."""
    for item in output_items:
        item_type = item.get("type")
        if item_type == "function_call":
            return True
        if item_type == "message" and any(
            part.get("type") == "output_text" and (part.get("text") or "").strip()
            for part in item.get("content", [])
        ):
            return True
    return False


def _stitch_rollout_turns(
    output_items: list[dict[str, Any]], rollout_index: int
) -> tuple[list[int], list[int], list[int], list[float], int]:
    """Concatenate the token-level turns of one rollout into a single completion.

    Only items carrying both ``prompt_token_ids`` and ``generation_token_ids`` count as
    turns. The first turn's prompt becomes the rollout prompt. For each later turn, the
    prompt tokens beyond what has been seen so far (previous prompt + generation) are
    appended with mask 0 and logprob 0.0; generated tokens are appended with mask 1.

    Returns:
        ``(first_prompt, rollout_ids, rollout_mask, rollout_logprobs, num_turns)``.

    Raises:
        ValueError: If a turn's prompt does not extend the tokens seen so far, if the
            logprob count differs from the generated token count, or if no turn
            produced any tokens.
    """
    rollout_ids: list[int] = []
    rollout_mask: list[int] = []
    rollout_logprobs: list[float] = []
    seen_token_ids: list[int] = []
    first_prompt = None
    num_turns = 0

    for item in output_items:
        if "prompt_token_ids" not in item or "generation_token_ids" not in item:
            continue

        num_turns += 1
        item_prompt_ids = item["prompt_token_ids"]
        item_gen_ids = item["generation_token_ids"]
        item_logprobs = item.get("generation_log_probs", [])

        if first_prompt is None:
            first_prompt = item_prompt_ids
        elif len(item_prompt_ids) > len(seen_token_ids):
            if item_prompt_ids[: len(seen_token_ids)] != seen_token_ids:
                raise ValueError(
                    f"[Turn {num_turns}] Non-contiguous messages (tokenization issue). "
                    f"Expected prefix len {len(seen_token_ids)}, got prompt len {len(item_prompt_ids)}"
                )
            # Tokens inserted between turns (tool results) are kept in the sequence
            # but masked out of the loss.
            tool_result_tokens = item_prompt_ids[len(seen_token_ids) :]
            rollout_ids.extend(tool_result_tokens)
            rollout_mask.extend([0] * len(tool_result_tokens))
            rollout_logprobs.extend([0.0] * len(tool_result_tokens))

        # Explicit raise rather than ``assert``: a misaligned logprob list must never be
        # accepted silently when Python runs with ``-O``.
        if len(item_logprobs) != len(item_gen_ids):
            raise ValueError(f"Logprobs len {len(item_logprobs)} != gen len {len(item_gen_ids)}")

        rollout_ids.extend(item_gen_ids)
        rollout_mask.extend([1] * len(item_gen_ids))
        rollout_logprobs.extend(item_logprobs)
        seen_token_ids = list(item_prompt_ids) + list(item_gen_ids)

    if not rollout_ids or first_prompt is None:
        raise ValueError(f"Rollout {rollout_index} has no valid turns")

    return first_prompt, rollout_ids, rollout_mask, rollout_logprobs, num_turns


def nemo_gym_rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]:
    """Generate rollouts by sending dataset items to NeMo Gym agents.

    Each entry of ``prompts`` is a stringified row index into the train dataset (or the
    eval dataset when the model is not in training mode and one is set). The row's
    ``metadata`` JSON is the agent request payload, and it is sent ``num_generations``
    times. Responses are converted to token ids, a loss mask, logprobs and rewards.

    Args:
        prompts: Stringified dataset row indices.
        trainer: The trainer; its ``args``, ``model``, datasets and ``processing_class``
            are read, and ``trainer.log`` is called with turn statistics.

    Returns:
        Dict with ``prompt_ids`` (one entry per group of ``num_generations`` rollouts),
        and per-rollout ``completion_ids``, ``env_mask``, ``logprobs``, ``env_reward``
        and ``num_turns``.

    Raises:
        RuntimeError: If no rollouts were produced at all.
        ValueError: If a rollout with content has inconsistent or missing token data.
    """
    is_eval = not trainer.model.training
    num_generations = (
        trainer.args.num_generations_eval
        if is_eval and trainer.args.num_generations_eval
        else trainer.args.num_generations
    )
    dataset = trainer.eval_dataset if is_eval and trainer.eval_dataset is not None else trainer.train_dataset

    expanded_prompts: list[str] = []
    expanded_dataset_items: list[dict[str, Any]] = []
    for idx_str in prompts:
        item = json.loads(dataset[int(idx_str)]["metadata"])
        for _ in range(num_generations):
            expanded_prompts.append(idx_str)
            expanded_dataset_items.append(dict(item))

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        responses = loop.run_until_complete(
            call_nemo_gym_agents(
                expanded_prompts,
                expanded_dataset_items,
                trainer.args.agent_servers,
                trainer.args.request_timeout,
                trainer.args.max_completion_length,
                temperature=trainer.args.temperature,
                top_p=trainer.args.top_p,
            )
        )
    finally:
        loop.close()

    tokenizer = trainer.processing_class
    # Placeholder token for failed rollouts; looked up once, it does not change per response.
    eos_token_id = tokenizer.eos_token_id or 0

    prompt_ids: list[list[int]] = []
    completion_ids: list[list[int]] = []  # list of rollouts
    env_mask: list[list[int]] = []  # only train on assistant turns
    logprobs: list[list[float]] = []
    env_rewards: list[float] = []
    num_turns_list: list[int] = []

    for i, response in enumerate(responses):
        output_items: list[dict[str, Any]] = []
        if isinstance(response, dict) and not response.get("error"):
            # ``or`` fallbacks treat explicit ``None`` values like missing keys.
            output_items = (response.get("response") or {}).get("output") or []

        if not _has_usable_output(output_items):
            # Failed or empty rollout: emit a one-token, fully masked placeholder with
            # zero reward so the batch keeps one entry per request.
            prompt_ids.append([eos_token_id])
            completion_ids.append([eos_token_id])
            env_mask.append([0])
            logprobs.append([0.0])
            env_rewards.append(0.0)
            num_turns_list.append(0)
            continue

        first_prompt, rollout_ids, rollout_mask, rollout_logprobs, num_turns = _stitch_rollout_turns(
            output_items, i
        )

        prompt_ids.append(first_prompt)  # list of prompts
        completion_ids.append(rollout_ids)  # list of rollouts
        env_mask.append(rollout_mask)
        logprobs.append(rollout_logprobs)
        env_rewards.append(response.get("reward", 0.0))
        num_turns_list.append(num_turns)

    if not prompt_ids:
        raise RuntimeError("No valid rollouts. Check Nemo Gym and vLLM logs.")

    if num_turns_list:
        trainer.log(
            {
                "num_turns_mean": sum(num_turns_list) / len(num_turns_list),
                "num_turns_min": min(num_turns_list),
                "num_turns_max": max(num_turns_list),
            }
        )

    # Keep the prompt of the first rollout in each group of ``num_generations``.
    unique_prompt_ids = prompt_ids[::num_generations]

    return {
        "prompt_ids": unique_prompt_ids,
        "completion_ids": completion_ids,
        "env_mask": env_mask,
        "logprobs": logprobs,
        "env_reward": env_rewards,
        "num_turns": num_turns_list,
    }
def load_dataset_from_jsonl(path: str) -> Dataset:
    """Load a JSONL file into a dataset of ``{"prompt", "metadata"}`` rows.

    Blank lines are skipped. Every other line is parsed as JSON and stored re-serialized
    in ``metadata``. ``prompt`` holds the row's position in the resulting dataset as a
    string; the rollout function in this script converts it back with ``int(...)`` and
    uses it to index the dataset.

    Args:
        path: Path to the JSONL file.

    Returns:
        A ``Dataset`` with one row per non-blank line.
    """
    rows = []
    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            item = json.loads(line)
            rows.append(
                {
                    # Use the index for lookup, as not all NeMo Gym datasets have the same
                    # metadata fields. It must be the row position (not the file line
                    # number) so it stays valid when blank lines are skipped.
                    "prompt": str(len(rows)),
                    "metadata": json.dumps(item),
                }
            )
    return Dataset.from_list(rows)
def main():
    """Parse the CLI and YAML config, then run GRPO training against NeMo Gym agents.

    The YAML file must be a mapping containing ``model_name`` and ``dataset_path``.
    ``eval_dataset_path``, ``task`` and ``project_name`` are optional. Every remaining
    key is forwarded to ``NeMoGymGRPOConfig`` and takes precedence over the defaults
    this script sets.

    Raises:
        ValueError: If the config file does not contain a YAML mapping.
        KeyError: If ``model_name`` or ``dataset_path`` is missing from the config.
    """
    parser = argparse.ArgumentParser(description="")
    parser.add_argument("--config", required=True, help="Path to config YAML file")
    parser.add_argument("--vllm_server_host", type=str, default="127.0.0.1", help="vLLM server hostname/IP")
    parser.add_argument("--head_server_host", type=str, default="127.0.0.1", help="Head server hostname/IP for ng_run")
    parser.add_argument("--head_server_port", type=int, default=11000, help="Head server port for ng_run")
    parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="Path to checkpoint to resume from")
    args = parser.parse_args()

    with open(args.config, encoding="utf-8") as f:
        config = yaml.safe_load(f)
    if not isinstance(config, dict):
        # An empty file loads as ``None``; fail with a clear message instead of an
        # AttributeError on ``config.pop``.
        raise ValueError(f"Config file {args.config} must contain a YAML mapping, got {type(config).__name__}")

    model_name = config.pop("model_name")
    dataset_path = config.pop("dataset_path")
    eval_dataset_path = config.pop("eval_dataset_path", None)
    task = config.pop("task", None)
    project_name = config.pop("project_name", None)

    # These values may arrive as strings from the YAML file; convert them to float.
    for key in ("learning_rate", "weight_decay"):
        if isinstance(config.get(key), str):
            config[key] = float(config[key])

    agent_servers = get_agent_servers(
        head_server_host=args.head_server_host,
        head_server_port=args.head_server_port,
    )

    if project_name:
        os.environ["WANDB_PROJECT"] = project_name

    if dataset_path.endswith((".jsonl", ".json")):
        dataset = load_dataset_from_jsonl(dataset_path)
    else:
        dataset = load_dataset(dataset_path, split="train")

    eval_dataset = None
    if eval_dataset_path:
        eval_dataset = load_dataset_from_jsonl(eval_dataset_path)
        print(f"Eval dataset has {len(eval_dataset)} examples\n")

    default_kwargs = {
        "use_vllm": True,
        "vllm_mode": "server",
        "vllm_server_host": args.vllm_server_host,
        "vllm_server_port": 8000,
        "gradient_checkpointing": True,
        "num_generations_eval": 1,
        "logging_steps": 1,
        "epsilon": 0.2,
        "epsilon_high": 0.28,
        "loss_type": "grpo",
        "mask_truncated_completions": True,
        "shuffle_dataset": False,
        "model_init_kwargs": {"torch_dtype": "auto"},
        "agent_servers": agent_servers,
        "request_timeout": 10800,
    }
    # Merge instead of passing both as keyword arguments: a YAML key that is also a
    # script default would otherwise raise "got multiple values for keyword argument".
    training_args = NeMoGymGRPOConfig(**{**default_kwargs, **config})

    if training_args.run_name is None:
        task_name = task or os.path.basename(dataset_path).replace(".jsonl", "").replace(".json", "")
        model_short = model_name.split("/")[-1]
        training_args.run_name = (
            f"{task_name}_{model_short}"
            f"_rpp{training_args.num_generations}"
            f"_dbs{training_args.per_device_train_batch_size}"
            f"_ga{training_args.gradient_accumulation_steps}"
            f"_maxlen{training_args.max_completion_length}"
            f"_lr{training_args.learning_rate}"
            f"_temp{training_args.temperature}"
            f"_topp{training_args.top_p}"
        )

    tokenizer = AutoTokenizer.from_pretrained(model_name, truncation_side="left", padding_side="left")

    trainer = GRPOTrainer(
        model=model_name,
        processing_class=tokenizer,
        reward_funcs=reward_fn,
        train_dataset=dataset,
        eval_dataset=eval_dataset,
        rollout_func=nemo_gym_rollout_func,
        args=training_args,
    )

    trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)


if __name__ == "__main__":
    main()