# NOTE: "Spaces: / Sleeping / Sleeping" was Hugging Face Space status text
# captured along with this source; it is not part of the training script.
import json

from datasets import Dataset
from unsloth import FastLanguageModel, is_bfloat16_supported
from trl import GRPOConfig, GRPOTrainer

from client import CustomerEnv
from models import CustomerAction

# Remote OpenEnv server (a hosted Hugging Face Space) that scores agent actions.
REMOTE_ENV_URL = "https://ramnarayanan747-voice-agent.hf.space"
# Checkpoint produced by the earlier supervised fine-tuning (SFT) phase.
MODEL_PATH = "voice_agent_sft"
# Maximum sequence length passed to the model loader.
MAX_SEQ_LENGTH = 1024
def openenv_reward_func(prompts, completions, **kwargs):
    """Score each generated completion by executing it in the remote OpenEnv.

    This is the bridge between GRPO and the OpenEnv server: TRL passes the
    generated actions, each one is parsed as a JSON action and sent to the
    cloud environment via ``client.CustomerEnv``, and the environment's exact
    reward is returned.

    Args:
        prompts: Batch of prompts (unused; required by TRL's reward-func API).
        completions: Batch of generations — either raw strings or chat-format
            lists of ``{"role": ..., "content": ...}`` messages.
        **kwargs: Extra dataset columns TRL forwards (unused).

    Returns:
        list[float]: One reward per completion; -5.0 for malformed JSON,
        -2.0 for any other failure (environment logic error, network error).
    """
    rewards = []
    for response in completions:
        # Chat-format completions arrive as [{"role": ..., "content": ...}].
        text = response[0]["content"] if isinstance(response, list) else response
        try:
            # Parse first so malformed JSON is penalized before any network
            # call; re-serialize to send a normalized action message.
            action_dict = json.loads(text.strip())
            action_msg = json.dumps(action_dict)
            with CustomerEnv(base_url=REMOTE_ENV_URL) as env:
                env.reset()
                result = env.step(CustomerAction(message=action_msg))
            rewards.append(float(result.reward))
        except json.JSONDecodeError:
            # Major penalty if the model forgets its SFT training and outputs bad JSON.
            rewards.append(-5.0)
        except Exception as exc:
            # Minor penalty if the action is valid JSON but crashes the
            # environment logic; log it so failures are not silently invisible.
            print(f"Environment step failed: {exc}")
            rewards.append(-2.0)
    return rewards
# Instruction that fixes the agent's output contract: JSON actions only.
SYSTEM_PROMPT = "You are a banking Voice Agent. You must output JSON actions using 'speak' or 'tool_call'."

# Seed scenarios the RL loop explores; each becomes one chat-format prompt.
intents = [
    "Customer sees a $215.50 charge from 'TechStore Online'.",
    "Customer lost their wallet on the subway 10 minutes ago.",
    "Customer wants to check their checking account balance.",
]


def _build_prompt(intent):
    """Wrap one customer intent into the two-turn chat prompt GRPO trains on."""
    return [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": f"System: Call connected.\nCustomer: {intent}"},
    ]


dataset = Dataset.from_dict({"prompt": [_build_prompt(intent) for intent in intents]})
print(f"Loading SFT model from {MODEL_PATH}...")
# Load the supervised fine-tuned checkpoint in 4-bit, with Unsloth's fast
# generation path enabled for GRPO rollouts.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=MODEL_PATH,
    max_seq_length=MAX_SEQ_LENGTH,
    load_in_4bit=True,
    fast_inference=True,
)
# Re-apply LoRA adapters for the RL phase
model = FastLanguageModel.get_peft_model(
    model,
    r=16,  # LoRA rank
    # Adapt all attention and MLP projection matrices.
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    use_gradient_checkpointing="unsloth",  # memory-efficient checkpointing
)
# Configure the GRPO Trainer
training_args = GRPOConfig(
    use_vllm=True,  # generate rollouts with vLLM (pairs with fast_inference above)
    learning_rate=5e-6,  # Keep RL learning rate much lower than SFT
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    optim="paged_adamw_8bit",  # 8-bit paged optimizer to limit GPU memory
    logging_steps=1,
    # Pick bf16 when the hardware supports it; otherwise fall back to fp16.
    bf16=is_bfloat16_supported(),
    fp16=not is_bfloat16_supported(),
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,  # effective batch size of 4
    num_generations=4,  # How many different actions to test per prompt
    max_prompt_length=256,
    max_completion_length=256,
    max_steps=200,
    output_dir="grpo_outputs",
)
# Each training step generates candidate actions and scores them against the
# remote environment through the reward function.
trainer = GRPOTrainer(
    model=model,
    reward_funcs=[openenv_reward_func],  # reward comes from the remote OpenEnv
    args=training_args,
    train_dataset=dataset,
)
print("Starting RL loop over the remote OpenEnv environment...")
trainer.train()
print("Saving final RL-optimized agent...")
# Save the LoRA-adapted model and its tokenizer together so the output
# directory is loadable as a unit.
model.save_pretrained("voice_agent_rl_final")
tokenizer.save_pretrained("voice_agent_rl_final")
print("Agent successfully trained!")