# RL-IVR / train_rl.py
# Uploaded by hrajgarhia943 using huggingface_hub (commit 13eda4d, verified).
import json
from datasets import Dataset
from unsloth import FastLanguageModel, is_bfloat16_supported
from trl import GRPOConfig, GRPOTrainer
from client import CustomerEnv
from models import CustomerAction
# Hosted OpenEnv customer-simulation environment; client.py wraps its HTTP API.
REMOTE_ENV_URL = "https://ramnarayanan747-voice-agent.hf.space"
# Local directory holding the SFT checkpoint this RL phase starts from.
MODEL_PATH = "voice_agent_sft"
# Context window passed to the model loader (prompt + completion tokens).
MAX_SEQ_LENGTH = 1024
def openenv_reward_func(prompts, completions, **kwargs):
    """
    Bridge between GRPO and the remote OpenEnv server.

    TRL calls this with the generated completions; each completion is parsed
    as a JSON action, sent to the cloud environment via client.py, and scored
    with the exact reward the environment assigns.

    Args:
        prompts: Batch prompts (unused; required by TRL's reward-func signature).
        completions: Generated responses — either plain strings or chat-style
            lists of ``{"role", "content"}`` message dicts.
        **kwargs: Extra dataset columns forwarded by TRL (ignored).

    Returns:
        list[float]: One reward per completion. -5.0 for malformed JSON,
        -2.0 if the action is valid JSON but the environment step fails.
    """
    rewards = []
    for response in completions:
        # Chat-format completions arrive as a list of message dicts.
        text = response[0]["content"] if isinstance(response, list) else response
        try:
            action_dict = json.loads(text.strip())
            # Round-trip through dumps so the env receives normalized JSON.
            action_msg = json.dumps(action_dict)
            # Fresh session per rollout so episodes can't leak state into
            # each other.
            with CustomerEnv(base_url=REMOTE_ENV_URL) as env:
                env.reset()
                result = env.step(CustomerAction(message=action_msg))
            rewards.append(float(result.reward))
        except json.JSONDecodeError:
            # Major penalty if the model forgets its SFT training and outputs bad JSON
            rewards.append(-5.0)
        except Exception as exc:
            # Minor penalty if the action is valid JSON but crashes the environment logic.
            # Log instead of silently swallowing so env failures stay debuggable.
            print(f"Environment step failed, assigning penalty: {exc!r}")
            rewards.append(-2.0)
    return rewards
# Instructions shown to the model on every episode.
SYSTEM_PROMPT = "You are a banking Voice Agent. You must output JSON actions using 'speak' or 'tool_call'."

# Opening customer scenarios used to seed the RL rollouts.
intents = [
    "Customer sees a $215.50 charge from 'TechStore Online'.",
    "Customer lost their wallet on the subway 10 minutes ago.",
    "Customer wants to check their checking account balance.",
]

# One chat-formatted prompt per scenario: system instructions followed by
# the customer's opening turn.
_prompts = [
    [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": f"System: Call connected.\nCustomer: {scenario}"},
    ]
    for scenario in intents
]
dataset = Dataset.from_dict({"prompt": _prompts})
print(f"Loading SFT model from {MODEL_PATH}...")
# Load the SFT checkpoint quantized to 4 bits; fast_inference=True enables
# Unsloth's fast generation backend used for the GRPO rollouts (paired with
# use_vllm=True in the training config below).
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=MODEL_PATH,
    max_seq_length=MAX_SEQ_LENGTH,
    load_in_4bit=True,
    fast_inference=True,
)
# Re-apply LoRA adapters for the RL phase
# Rank-16 adapters on all attention and MLP projection matrices;
# "unsloth" gradient checkpointing trades recompute for lower VRAM.
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    use_gradient_checkpointing="unsloth",
)
# Configure the GRPO Trainer
training_args = GRPOConfig(
    use_vllm=True,  # generate rollouts through vLLM (matches fast_inference above)
    learning_rate=5e-6, # Keep RL learning rate much lower than SFT
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    optim="paged_adamw_8bit",  # 8-bit paged optimizer to fit in limited VRAM
    logging_steps=1,
    # Pick bf16 when the GPU supports it, otherwise fall back to fp16.
    bf16=is_bfloat16_supported(),
    fp16=not is_bfloat16_supported(),
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,  # effective batch size of 4 prompts
    num_generations=4, # How many different actions to test per prompt
    max_prompt_length=256,
    max_completion_length=256,
    max_steps=200,  # fixed-step run rather than epoch-based
    output_dir="grpo_outputs",
)
# The trainer scores every batch of generations with openenv_reward_func,
# which queries the remote environment for each completion.
trainer = GRPOTrainer(
    model=model,
    reward_funcs=[openenv_reward_func],
    args=training_args,
    train_dataset=dataset,
)
print("Starting RL loop over the remote OpenEnv environment...")
trainer.train()
print("Saving final RL-optimized agent...")
# NOTE(review): with a PEFT-wrapped model this presumably saves only the LoRA
# adapter weights — confirm a merge step exists if full weights are needed.
model.save_pretrained("voice_agent_rl_final")
tokenizer.save_pretrained("voice_agent_rl_final")
print("Agent successfully trained!")