import json

from datasets import Dataset
from unsloth import FastLanguageModel, is_bfloat16_supported
from trl import GRPOConfig, GRPOTrainer

from client import CustomerEnv
from models import CustomerAction

REMOTE_ENV_URL = "https://ramnarayanan747-voice-agent.hf.space"
MODEL_PATH = "voice_agent_sft"
MAX_SEQ_LENGTH = 1024


def openenv_reward_func(prompts, completions, **kwargs):
    """
    The bridge between GRPO and your OpenEnv server.

    TRL passes in the generated actions; we forward them to the remote
    environment via client.py and return the exact reward it assigns.
    """
    rewards = []
    for response in completions:
        # Conversational completions arrive as a list of message dicts;
        # plain-text completions arrive as strings.
        text = response[0]["content"] if isinstance(response, list) else response
        try:
            action_dict = json.loads(text.strip())
            action_msg = json.dumps(action_dict)
            # One fresh episode per candidate action so rewards stay independent.
            with CustomerEnv(base_url=REMOTE_ENV_URL) as env:
                env.reset()
                result = env.step(CustomerAction(message=action_msg))
                rewards.append(float(result.reward))
        except json.JSONDecodeError:
            # Major penalty if the model forgets its SFT training and outputs bad JSON
            rewards.append(-5.0)
        except Exception:
            # Minor penalty if the action is valid JSON but crashes the environment logic
            rewards.append(-2.0)
    return rewards


SYSTEM_PROMPT = (
    "You are a banking Voice Agent. You must output JSON actions "
    "using 'speak' or 'tool_call'."
)

intents = [
    "Customer sees a $215.50 charge from 'TechStore Online'.",
    "Customer lost their wallet on the subway 10 minutes ago.",
    "Customer wants to check their checking account balance.",
]

dataset = Dataset.from_dict({
    "prompt": [
        [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": f"System: Call connected.\nCustomer: {intent}"},
        ]
        for intent in intents
    ]
})

print(f"Loading SFT model from {MODEL_PATH}...")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=MODEL_PATH,
    max_seq_length=MAX_SEQ_LENGTH,
    load_in_4bit=True,
    fast_inference=True,  # enables vLLM-backed generation for GRPO rollouts
)

# Re-apply LoRA adapters for the RL phase
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    use_gradient_checkpointing="unsloth",
)

# Configure the GRPO trainer
training_args = GRPOConfig(
    use_vllm=True,
    learning_rate=5e-6,  # keep the RL learning rate much lower than the SFT one
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    optim="paged_adamw_8bit",
    logging_steps=1,
    bf16=is_bfloat16_supported(),
    fp16=not is_bfloat16_supported(),
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    num_generations=4,  # how many candidate actions to sample per prompt
    max_prompt_length=256,
    max_completion_length=256,
    max_steps=200,
    output_dir="grpo_outputs",
)

trainer = GRPOTrainer(
    model=model,
    reward_funcs=[openenv_reward_func],
    args=training_args,
    train_dataset=dataset,
)

print("Starting RL loop over the remote OpenEnv environment...")
trainer.train()

print("Saving final RL-optimized agent...")
model.save_pretrained("voice_agent_rl_final")
tokenizer.save_pretrained("voice_agent_rl_final")
print("Agent successfully trained!")
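
# --- Optional sanity check (a minimal sketch, run before trainer.train()) ---
# Verifies that the remote Space at REMOTE_ENV_URL responds and that
# openenv_reward_func scores a well-formed action, before committing to the
# 200-step run. The sample action below is hypothetical; match it to whatever
# JSON schema your SFT data actually used.
#
# sample_action = json.dumps({"speak": "Hello, how can I help you today?"})
# print(openenv_reward_func(
#     prompts=[None],  # unused by the reward function
#     completions=[[{"role": "assistant", "content": sample_action}]],
# ))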
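
# --- Optional smoke test of the saved agent (a minimal sketch) ---
# Assumes the chat template from the SFT phase is still attached to the saved
# tokenizer; run in a fresh session so the training run's GPU memory is freed.
#
# model, tokenizer = FastLanguageModel.from_pretrained(
#     model_name="voice_agent_rl_final",
#     max_seq_length=MAX_SEQ_LENGTH,
#     load_in_4bit=True,
# )
# FastLanguageModel.for_inference(model)
# messages = [
#     {"role": "system", "content": SYSTEM_PROMPT},
#     {"role": "user", "content": "System: Call connected.\nCustomer: I lost my card."},
# ]
# inputs = tokenizer.apply_chat_template(
#     messages, add_generation_prompt=True, return_tensors="pt"
# ).to(model.device)
# print(tokenizer.decode(model.generate(inputs, max_new_tokens=128)[0]))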