# recruitopenenv / train_grpo.py
# Uploaded via huggingface_hub by bathientran (revision be37527, verified).
"""
GRPO training script for the Driver Recruit Environment.
Uses TRL's GRPOTrainer with rollout_func for multi-turn episodes.
The model controls EVERY action in the episode via tool calls.
Usage:
python train_grpo.py --model Qwen/Qwen2.5-3B-Instruct --use-qlora
"""
import argparse
import json
import random
from datasets import Dataset
from transformers import AutoTokenizer, BitsAndBytesConfig
import torch
from recruitopenenv import RecruitopenenvEnv, RecruitopenenvAction
from trl import GRPOConfig, GRPOTrainer
from trl.experimental.openenv import generate_rollout_completions
# --- Prompt templates ---
# System prompt shown to the policy model every episode. It describes the
# recruiter task, the four tool namespaces, the hard ordering rules, the
# intended workflow, and — critically — the exact one-line JSON action format
# that parse_action() expects to find in the model's completion.
# NOTE: this is a runtime string sent to the model; do not reflow or edit
# casually, as wording changes alter model behavior.
SYSTEM_PROMPT = """You are a truck driver recruiter using a CRM system. You only know the driver's name. You must discover their qualifications through conversation, record info in the CRM, get approval, and hire them.
You have 4 tools:
## crm
- read_candidate: Read the current CRM record
- update_stage: Advance pipeline (contacted β†’ interested β†’ approval_pending β†’ offer_sent β†’ hired)
- update_field: Record info (field + value)
- add_note: Add a free-text note
## messaging
- send_message: Send a message (topic: greeting, call, experience, home_time, pay, equipment, route, deal_breakers, availability, violations, medical_card, references, pitch, offer, negotiate_pay, negotiate_home_time, signing_bonus, address_concern)
- read_reply: Read the driver's response
## approval
- request_approval: Request approval for a job (needs job_id)
- check_approval: Check approval status
## workflow
- wait: Advance time (needed for approval processing)
## Rules
- Must read CRM before messaging
- Must read_reply before sending another message
- Must request_approval and wait before sending offer
- Must follow stage order: lead β†’ contacted β†’ interested β†’ approval_pending β†’ offer_sent β†’ hired
- Record important info in CRM with update_field
- Too many messages hurt trust
## Workflow
1. crm.read_candidate
2. messaging.send_message (greeting/call) β†’ read_reply β†’ update_stage(contacted)
3. messaging.send_message (screening topics) β†’ read_reply β†’ crm.update_field
4. crm.update_stage(interested)
5. approval.request_approval β†’ workflow.wait β†’ approval.check_approval
6. crm.update_stage(approval_pending)
7. messaging.send_message(offer) β†’ read_reply
8. crm.update_stage(offer_sent) β†’ crm.update_stage(hired)
Respond with ONLY JSON:
{"tool": "crm", "action": "read_candidate"}
{"tool": "messaging", "action": "send_message", "topic": "experience"}
{"tool": "messaging", "action": "read_reply"}
{"tool": "crm", "action": "update_field", "field": "cdl_class", "value": "A"}
{"tool": "crm", "action": "update_stage", "stage": "contacted"}
{"tool": "approval", "action": "request_approval", "job_id": 2}
{"tool": "workflow", "action": "wait"}
{"tool": "approval", "action": "check_approval"}
{"tool": "messaging", "action": "send_message", "topic": "offer", "job_id": 2}
{"tool": "crm", "action": "update_stage", "stage": "hired"}"""
def format_observation(obs):
    """Render an environment observation as the full user-turn prompt.

    Includes driver name, any non-empty CRM/jobs/discovered summaries, the
    current pipeline stage (with a PENDING REPLY flag), and step feedback.
    """
    lines = [f"Driver: {obs.driver_name}"]
    # Optional multi-line sections, emitted only when non-empty.
    for label, content in (
        ("CRM", obs.crm_summary),
        ("Jobs", obs.jobs_summary),
        ("Discovered", obs.discovered_info),
    ):
        if content:
            lines.append(f"{label}:\n{content}")
    stage_line = f"Stage: {obs.stage}"
    if obs.pending_reply:
        stage_line += " | PENDING REPLY"
    lines.append(stage_line)
    if obs.feedback:
        lines.append(f"Result: {obs.feedback}")
    return "\n".join(lines)
def format_observation_compact(obs):
    """Produce a short observation summary (~30-60 tokens) for injection
    between turns inside completion_ids.

    Feedback and discovered-info are each capped at 200 characters.
    """
    chunks = [f"Stage: {obs.stage}"]
    if obs.pending_reply:
        chunks.append("PENDING REPLY")
    # Keep the transition small: truncate free-text fields aggressively.
    for extra in (obs.feedback, obs.discovered_info):
        if extra:
            chunks.append(extra[:200])
    return "\n".join(chunks)
def parse_action(text):
    """Parse LLM output into a RecruitopenenvAction.

    Tries, in order: stripping markdown fences and decoding JSON; a keyword
    fallback on the raw text; and finally a safe default (crm.read_candidate).
    """
    text = text.strip()
    # Strip markdown code fences, keeping the first JSON-looking segment.
    if "```" in text:
        for segment in text.split("```"):
            segment = segment.strip()
            if segment.startswith("json"):
                segment = segment[4:].strip()
            if segment.startswith("{"):
                text = segment
                break
    # Primary path: a JSON object (or a list whose first item is one).
    try:
        payload = json.loads(text)
        if isinstance(payload, list):
            payload = payload[0] if payload else {}
        if isinstance(payload, dict) and "tool" in payload and "action" in payload:
            return RecruitopenenvAction(
                tool=payload["tool"],
                action=payload["action"],
                topic=payload.get("topic", ""),
                job_id=int(payload.get("job_id", -1)),
                stage=str(payload.get("stage", "")),
                field=str(payload.get("field", "")),
                value=str(payload.get("value", "")),
            )
    except (json.JSONDecodeError, KeyError, IndexError, ValueError, TypeError):
        pass
    # Fallback: detect intent from bare keywords in the raw text.
    lowered = text.lower()
    keyword_actions = (
        ("read_candidate", "crm", "read_candidate"),
        ("read_reply", "messaging", "read_reply"),
        ("check_approval", "approval", "check_approval"),
        ("wait", "workflow", "wait"),
    )
    for needle, tool, action in keyword_actions:
        if needle in lowered:
            return RecruitopenenvAction(tool=tool, action=action)
    # Last resort: reading the CRM is always a legal, harmless move.
    return RecruitopenenvAction(tool="crm", action="read_candidate")
# --- Multi-turn rollout ---
# Base URL of the environment server; overridden from --env-url in main()
# via `global ENV_URL` so rollout_func() picks up the CLI value.
ENV_URL = "http://localhost:8001"
# Token budget for a whole episode's completion_ids; must stay in sync with
# GRPOConfig.max_completion_length configured in main().
MAX_COMPLETION_TOKENS = 1536
def _build_chat_transition(tokenizer, obs_text):
"""Build chat-formatted transition tokens: end assistant turn, user obs, start assistant.
Result: <|im_end|>\n<|im_start|>user\n{obs}<|im_end|>\n<|im_start|>assistant\n
This ensures the model sees proper chat structure during the forward pass.
"""
im_start = tokenizer.convert_tokens_to_ids("<|im_start|>")
im_end = tokenizer.convert_tokens_to_ids("<|im_end|>")
# Encode role tags and newlines
nl = tokenizer.encode("\n", add_special_tokens=False)
user_tag = tokenizer.encode("user", add_special_tokens=False)
asst_tag = tokenizer.encode("assistant", add_special_tokens=False)
obs_ids = tokenizer.encode(obs_text, add_special_tokens=False)[:60]
# <|im_end|>\n<|im_start|>user\n{obs}<|im_end|>\n<|im_start|>assistant\n
return (
[im_end] + nl +
[im_start] + user_tag + nl +
obs_ids +
[im_end] + nl +
[im_start] + asst_tag + nl
)
def rollout_once(trainer, env, tokenizer, prompt_text, system_prompt, max_turns=15):
    """Run one multi-turn episode with chat-formatted transitions.

    completion_ids: [action1, <|im_end|>user obs<|im_start|>assistant, action2, ...]
    The chat template structure lets the forward pass assign proper logprobs.
    Model-generated action tokens get env_mask=1 (trainable); injected
    observation/transition tokens get env_mask=0 and logprob 0.0.

    NOTE(review): ``prompt_text`` is accepted but never read — the prompt is
    rebuilt from ``messages`` every turn. Confirm whether it can be dropped
    from callers, or is kept for rollout_func signature symmetry.

    Returns a dict with prompt_ids, completion_ids, logprobs, env_mask (all
    truncated to MAX_COMPLETION_TOKENS), total env_reward, step count, and
    the final pipeline stage.
    """
    # Fresh random seed per episode so resets vary across rollouts.
    seed = random.randint(0, 2**31 - 1)
    result = env.reset(seed=seed)
    obs = result.observation
    prompt_ids = []
    completion_ids = []
    logprobs = []
    env_mask = []
    total_reward = 0.0
    steps = 0
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": format_observation(obs)},
    ]
    while not result.done and steps < max_turns:
        # Check if we're near the token budget (need room for action + transition)
        if len(completion_ids) > MAX_COMPLETION_TOKENS - 60:
            break
        current_prompt = tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=False
        )
        rollout_outputs = generate_rollout_completions(trainer, [current_prompt])[0]
        # Only the first turn's prompt tokens are reported as the episode prompt.
        if steps == 0:
            prompt_ids = list(rollout_outputs["prompt_ids"])
        action_ids = list(rollout_outputs["completion_ids"])
        action_logprobs = list(rollout_outputs["logprobs"])
        # Add action tokens (these get gradients)
        completion_ids.extend(action_ids)
        logprobs.extend(action_logprobs)
        env_mask.extend([1] * len(action_ids))
        # Prefer the backend-provided text; fall back to decoding the ids.
        response = rollout_outputs.get("text") or tokenizer.decode(
            action_ids, skip_special_tokens=True
        )
        messages.append({"role": "assistant", "content": response})
        action = parse_action(response)
        result = env.step(action)
        obs = result.observation
        total_reward += result.reward
        steps += 1
        if not result.done:
            # Build chat-formatted transition so forward pass sees proper structure.
            # The compact form goes into completion_ids; the full observation
            # goes into `messages` for the next generation prompt.
            obs_text = format_observation_compact(obs)
            transition_ids = _build_chat_transition(tokenizer, obs_text)
            completion_ids.extend(transition_ids)
            logprobs.extend([0.0] * len(transition_ids))
            env_mask.extend([0] * len(transition_ids))
            messages.append({"role": "user", "content": format_observation(obs)})
    # Truncate to fit max_completion_length (all three stay aligned).
    completion_ids = completion_ids[:MAX_COMPLETION_TOKENS]
    logprobs = logprobs[:MAX_COMPLETION_TOKENS]
    env_mask = env_mask[:MAX_COMPLETION_TOKENS]
    return {
        "prompt_ids": prompt_ids,
        "completion_ids": completion_ids,
        "logprobs": logprobs,
        "env_mask": env_mask,
        "env_reward": total_reward,
        "steps": steps,
        "final_stage": obs.stage,
    }
def rollout_func(prompts, trainer):
    """Multi-turn rollout: model controls every action in the episode.

    Args:
        prompts: Chat-templated prompt strings, one episode each.
        trainer: GRPOTrainer whose generation backend produces completions.

    Returns:
        Dict with per-episode prompt_ids, completion_ids, per-token logprobs
        (each wrapped in a 1-tuple, the shape TRL expects), env_reward, and
        env_mask (1 = model-generated token, 0 = injected transition token).

    Fixes vs. original: env.close() now runs even if an episode raises
    (try/finally); summary stats no longer raise ZeroDivisionError on an
    empty batch or print NaN std for a single-episode batch.
    """
    tokenizer = trainer.processing_class
    env = RecruitopenenvEnv(base_url=ENV_URL)
    all_prompt_ids = []
    all_completion_ids = []
    all_logprobs = []
    all_env_rewards = []
    all_env_mask = []
    try:
        for prompt_text in prompts:
            episode = rollout_once(trainer, env, tokenizer, prompt_text, SYSTEM_PROMPT)
            if episode["completion_ids"]:
                all_prompt_ids.append(episode["prompt_ids"])
                all_completion_ids.append(episode["completion_ids"])
                all_logprobs.append(episode["logprobs"])
                all_env_mask.append(episode["env_mask"])
            else:
                # Degenerate episode: substitute a harmless one-token completion
                # so downstream batch shapes stay valid.
                tok_ids = tokenizer.encode("wait", add_special_tokens=False)
                all_prompt_ids.append(episode["prompt_ids"] or tok_ids)
                all_completion_ids.append(tok_ids)
                all_logprobs.append([0.0] * len(tok_ids))
                all_env_mask.append([1] * len(tok_ids))
            all_env_rewards.append(episode["env_reward"])
            print(f" Episode {len(all_env_rewards)}: reward={episode['env_reward']:.1f}, "
                  f"steps={episode['steps']}, stage={episode['final_stage']}")
    finally:
        # Always release the environment connection, even on mid-batch failure.
        env.close()
    n_episodes = len(all_env_rewards)
    # Guard the stats: empty batch would divide by zero, and torch std() of
    # fewer than 2 elements is NaN.
    mean_r = sum(all_env_rewards) / n_episodes if n_episodes else 0.0
    std_r = torch.tensor(all_env_rewards).std().item() if n_episodes > 1 else 0.0
    print(f"Rollout done: {n_episodes} episodes, mean_reward={mean_r:.2f}, std={std_r:.2f}")
    return {
        "prompt_ids": all_prompt_ids,
        "completion_ids": all_completion_ids,
        "logprobs": [[(lp,) for lp in seq] for seq in all_logprobs],
        "env_reward": all_env_rewards,
        "env_mask": all_env_mask,
    }
# --- Reward function (fallback, rewards come from rollout) ---
def reward_total(completions, **kwargs):
    """Return per-episode rewards forwarded from rollout_func via kwargs.

    Falls back to all-zero rewards when no ``env_reward`` was supplied.
    """
    rewards = kwargs.get("env_reward", [])
    if not rewards:
        return [0.0] * len(completions)
    return [float(r) for r in rewards]
# --- Main ---
def main():
    """Build an episode-seeded prompt dataset, configure GRPO training
    (optionally with QLoRA 4-bit quantization), and launch multi-turn
    tool-use training against the environment server."""
    parser = argparse.ArgumentParser(description="GRPO training for Driver Recruit Environment")
    parser.add_argument("--model", default="Qwen/Qwen2.5-1.5B-Instruct", help="Model to train")
    parser.add_argument("--env-url", default="http://localhost:8001", help="Environment server URL")
    parser.add_argument("--num-episodes", type=int, default=16, help="Number of training episodes (dataset size)")
    parser.add_argument("--num-generations", type=int, default=4, help="GRPO generations per prompt")
    parser.add_argument("--batch-size", type=int, default=2, help="Per-device batch size")
    parser.add_argument("--epochs", type=int, default=1, help="Number of training epochs")
    parser.add_argument("--lr", type=float, default=5e-5, help="Learning rate")
    parser.add_argument("--output-dir", default="./recruit-grpo-output", help="Output directory")
    parser.add_argument("--vllm-mode", default="colocate", choices=["colocate", "server"],
                        help="vLLM mode: colocate (1 GPU) or server (2+ GPUs)")
    parser.add_argument("--use-qlora", action="store_true", help="Use QLoRA (4-bit) for memory efficiency")
    parser.add_argument("--lora-r", type=int, default=16, help="LoRA rank")
    parser.add_argument("--lora-alpha", type=int, default=32, help="LoRA alpha")
    args = parser.parse_args()
    # rollout_func reads the module-level ENV_URL, so propagate the CLI value.
    global ENV_URL
    ENV_URL = args.env_url
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    # Seed the training dataset: one chat-templated prompt per env reset.
    # Episodes are replayed from scratch inside rollout_func; these prompts
    # only supply the initial observation text.
    prompts = []
    env = RecruitopenenvEnv(base_url=args.env_url)
    for i in range(args.num_episodes):
        result = env.reset()
        obs = result.observation
        user_prompt = format_observation(obs)
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt},
        ]
        prompt_text = tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=False
        )
        prompts.append(prompt_text)
    env.close()
    dataset = Dataset.from_dict({"prompt": prompts})
    peft_config = None
    model_kwargs = {}
    if args.use_qlora:
        # Lazy import: peft is only needed on the QLoRA path.
        from peft import LoraConfig
        peft_config = LoraConfig(
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            lora_dropout=0.05,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                            "gate_proj", "up_proj", "down_proj"],
            task_type="CAUSAL_LM",
        )
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_type="nf4",
        )
        print(f"Using QLoRA: r={args.lora_r}, alpha={args.lora_alpha}, 4-bit")
    grpo_config = GRPOConfig(
        output_dir=args.output_dir,
        use_vllm=True,
        vllm_mode=args.vllm_mode,
        num_train_epochs=args.epochs,
        num_generations=args.num_generations,
        # NOTE(review): 1536 duplicates MAX_COMPLETION_TOKENS defined above —
        # the two must stay in sync.
        max_completion_length=1536,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=4,
        gradient_checkpointing=True,
        learning_rate=args.lr,
        temperature=0.7,
        logging_steps=1,
        save_steps=50,
        bf16=True,
        report_to="wandb",
        run_name="recruit-grpo-tools",
        model_init_kwargs=model_kwargs if model_kwargs else None,
    )
    trainer_kwargs = dict(
        model=args.model,
        processing_class=tokenizer,
        reward_funcs=[reward_total],
        train_dataset=dataset,
        args=grpo_config,
        rollout_func=rollout_func,
    )
    # peft_config is only a valid kwarg when set; omit it otherwise.
    if peft_config is not None:
        trainer_kwargs["peft_config"] = peft_config
    trainer = GRPOTrainer(**trainer_kwargs)
    print("=" * 50)
    print(f"Training {args.model} (TOOL-BASED MULTI-TURN)")
    print(f"Environment: {args.env_url}")
    print(f"QLoRA: {args.use_qlora}")
    print(f"Episodes: {args.num_episodes}")
    print(f"Epochs: {args.epochs}")
    print(f"Generations per prompt: {args.num_generations}")
    print("=" * 50)
    trainer.train()
    trainer.save_model(args.output_dir)
    print(f"\nModel saved to {args.output_dir}")
# Script entry point.
if __name__ == "__main__":
    main()