"""GRPO training for customer support agent.

Trains the base model (Qwen/Qwen3-8B by default) with LoRA using TRL's GRPOTrainer.
Uses pre-collected trajectory data as prompts, with online generation + reward functions.

Usage:
    python scripts/train_grpo.py --dataset-path ./data/quick_test/grpo_dataset
"""

import argparse
import json
import re

from datasets import Dataset, load_from_disk
from peft import LoraConfig
from transformers import AutoTokenizer
from trl import GRPOTrainer, GRPOConfig


# --- Reward functions ---
# These score model-generated completions during training.
# Signature: func(completions, **kwargs) -> list[float]
# completions = list of [{"role": "assistant", "content": "..."}]

TOOL_CALL_RE = re.compile(
    r"<tool_call>\s*<function=(\w+)>(.*?)</function>\s*</tool_call>", re.DOTALL
)
PARAM_RE = re.compile(r"<parameter=(\w+)>(.*?)</parameter>", re.DOTALL)

VALID_TOOLS = {"lookup_customer", "check_order", "send_reply", "update_ticket"}
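
# Illustrative completion the regexes above are meant to match. The tool name is
# taken from VALID_TOOLS; the parameter name and value are made-up examples, not
# drawn from the dataset:
#
#   I'll check the order status first.
#   <tool_call>
#   <function=check_order>
#   <parameter=order_id>ORD-1234</parameter>
#   </function>
#   </tool_call>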


def _get_text(completion):
    if isinstance(completion, list):
        return completion[0]["content"] if completion else ""
    return str(completion)


def _parse_answer(ans):
    if isinstance(ans, str):
        try:
            return json.loads(ans)
        except json.JSONDecodeError:
            return {}
    return ans if isinstance(ans, dict) else {}


def format_reward(completions, **kwargs) -> list[float]:
    """Valid XML tool call format?

    +1.0  valid <tool_call> with function + params
    +0.5  valid <tool_call> with a function but no params
    +0.25 partial structure
    -1.0  empty or no tool call
    """
    rewards = []
    for completion in completions:
        text = _get_text(completion)
        if not text.strip():
            rewards.append(-1.0)
            continue
        match = TOOL_CALL_RE.search(text)
        if match:
            params = PARAM_RE.findall(match.group(2))
            rewards.append(1.0 if params else 0.5)
        elif "<tool_call>" in text or "<function=" in text:
            rewards.append(0.25)
        else:
            rewards.append(-1.0)
    return rewards


def tool_validity_reward(completions, answer=None, **kwargs) -> list[float]:
    """Is the tool name valid?

    +1.0 valid tool name
    -0.5 parseable but wrong tool
     0.0 no tool call
    """
    rewards = []
    answers = answer or [None] * len(completions)
    for completion, ans in zip(completions, answers):
        text = _get_text(completion)
        ans_data = _parse_answer(ans) if ans else {}
        valid = set(ans_data.get("valid_tools", VALID_TOOLS))
        match = TOOL_CALL_RE.search(text)
        if match:
            tool_name = match.group(1).strip()
            rewards.append(1.0 if tool_name in valid else -0.5)
        else:
            rewards.append(0.0)
    return rewards


def reasoning_reward(completions, **kwargs) -> list[float]:
    """Does the completion have reasoning before the tool call?

    +1.0 reasoning text (>30 chars) before <tool_call>
    +0.3 has a tool call but little or no reasoning before it
    -0.5 no tool call at all
    """
    rewards = []
    for completion in completions:
        text = _get_text(completion)
        tool_pos = text.find("<tool_call>")
        if tool_pos > 30:
            rewards.append(1.0)
        elif tool_pos >= 0:
            rewards.append(0.3)
        else:
            rewards.append(-0.5)
    return rewards


def no_reasoning_leak_reward(completions, **kwargs) -> list[float]:
    """Penalize leaking reasoning into the send_reply message.

    The model should NOT send its internal reasoning as the customer message.
    +1.0 send_reply message doesn't start with reasoning patterns
    -1.0 send_reply message starts with "The customer", "I need to", etc.
     0.0 not a send_reply (neutral)
    """
    reasoning_starts = [
        "The customer", "I need to", "Let me", "From the", "Based on",
        "Looking at", "According to", "I should", "First,", "Now I",
        "The user", "I'll", "I will",
    ]
    rewards = []
    for completion in completions:
        text = _get_text(completion)
        match = TOOL_CALL_RE.search(text)
        if match and match.group(1).strip() == "send_reply":
            # Check if the message parameter starts with reasoning
            params = dict(PARAM_RE.findall(match.group(2)))
            msg = params.get("message", "").strip()
            if any(msg.startswith(p) for p in reasoning_starts):
                rewards.append(-1.0)
            else:
                rewards.append(1.0)
        else:
            rewards.append(0.0)  # Not a send_reply, neutral
    return rewards


def action_quality_reward(completions, answer=None, **kwargs) -> list[float]:
    """Heuristic quality: does the completion reference ground truth values?

    +0.3 per ground truth value referenced, capped at 1.0
    """
    rewards = []
    answers = answer or [None] * len(completions)
    for completion, ans in zip(completions, answers):
        text = _get_text(completion)
        ans_data = _parse_answer(ans) if ans else {}
        ground_truths = ans_data.get("ground_truth_values", [])
        score = 0.0
        for gt in ground_truths:
            if gt and str(gt).lower() in text.lower():
                score += 0.3
        rewards.append(min(1.0, score))
    return rewards


# --- Dataset preparation ---

def prepare_dataset(dataset_path: str) -> Dataset:
    """Load the collected trajectory dataset and format it for GRPO.

    GRPO needs: 'prompt' (list of message dicts), plus any extra columns
    for reward functions (like 'answer' with ground truth metadata).
    """
    ds = load_from_disk(dataset_path)
    print(f"Loaded dataset: {len(ds)} examples")
    print(f"Columns: {ds.column_names}")
    # The dataset has 'prompt' (a list of messages) and 'answer' (a JSON string).
    # GRPOTrainer expects 'prompt' to be the conversation so far; it generates
    # completions online and scores them with the reward functions, so each
    # prompt should already be a list of message dicts.
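    # Illustrative row shape (hypothetical values, inferred from the reward
    # functions above rather than copied from the real dataset):
    #   {
    #       "prompt": [{"role": "system", "content": "..."},
    #                  {"role": "user", "content": "..."}],
    #       "answer": '{"valid_tools": ["check_order"], "ground_truth_values": ["ORD-1234"]}',
    #   }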
    return ds


# --- Main ---

def main():
    parser = argparse.ArgumentParser(description="GRPO training for customer support agent")
    parser.add_argument("--model", default="Qwen/Qwen3-8B", help="Base model")
    parser.add_argument("--dataset-path", default="./data/quick_test/grpo_dataset", help="HF dataset path")
    parser.add_argument("--output-dir", default="./outputs/grpo-support-agent", help="Output directory")
    parser.add_argument("--num-epochs", type=int, default=1, help="Training epochs")
    parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate")
    parser.add_argument("--batch-size", type=int, default=2, help="Per-device batch size")
    parser.add_argument("--grad-accum", type=int, default=4, help="Gradient accumulation steps")
    parser.add_argument("--num-generations", type=int, default=4, help="Completions per prompt")
    parser.add_argument("--max-completion-length", type=int, default=512, help="Max generation tokens")
    parser.add_argument("--lora-r", type=int, default=32, help="LoRA rank")
    args = parser.parse_args()

    # Load dataset
    dataset = prepare_dataset(args.dataset_path)

    # Load the tokenizer explicitly. Multimodal Qwen checkpoints ship a VL
    # processor, and AutoProcessor would then expect multimodal content blocks,
    # so we force a text-only AutoTokenizer here.
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # LoRA config
    peft_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    )

    # GRPO config
    training_args = GRPOConfig(
        output_dir=args.output_dir,
        learning_rate=args.lr,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        num_train_epochs=args.num_epochs,
        num_generations=args.num_generations,
        max_completion_length=args.max_completion_length,
        temperature=0.7,
        beta=0.0,  # No KL penalty
        loss_type="grpo",
        scale_rewards="group",
        bf16=True,
        logging_steps=5,
        save_steps=100,
        log_completions=True,
        # Training stability
        max_grad_norm=1.0,
        warmup_ratio=0.1,
    )

    # Reward functions (weighted)
    reward_funcs = [
        format_reward,             # Valid XML tool call format
        tool_validity_reward,      # Correct tool name
        reasoning_reward,          # Reasoning before action
        no_reasoning_leak_reward,  # Don't leak reasoning into replies
        action_quality_reward,     # References ground truth values
    ]
    reward_weights = [1.0, 0.8, 0.5, 0.7, 0.5]
    # Wire the weights into the config so GRPOTrainer actually applies them;
    # GRPOConfig.reward_weights takes one weight per reward function.
    training_args.reward_weights = reward_weights

    print("\n=== GRPO Training ===")
    print(f"Model: {args.model}")
    print(f"Dataset: {len(dataset)} examples")
    print(f"LoRA rank: {args.lora_r}")
    print(f"Batch size: {args.batch_size} x {args.grad_accum} grad accum")
    print(f"Generations per prompt: {args.num_generations}")
    print(f"Reward functions: {[f.__name__ for f in reward_funcs]}")
    print(f"Reward weights: {reward_weights}")
    print()

    trainer = GRPOTrainer(
        model=args.model,
        processing_class=tokenizer,  # Force text-only tokenizer (not a VL processor)
        args=training_args,
        reward_funcs=reward_funcs,
        train_dataset=dataset,
        peft_config=peft_config,
    )
    trainer.train()

    # Save the LoRA adapter
    trainer.save_model(args.output_dir + "/final_adapter")
    print(f"\nTraining complete! Adapter saved to {args.output_dir}/final_adapter")
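    # To reuse the adapter later, a minimal loading sketch (assumes peft and
    # transformers are available; names below are illustrative):
    #   from transformers import AutoModelForCausalLM
    #   from peft import PeftModel
    #   base = AutoModelForCausalLM.from_pretrained(args.model)
    #   model = PeftModel.from_pretrained(base, f"{args.output_dir}/final_adapter")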


if __name__ == "__main__":
    main()