Ram Narayanan committed
Commit ba88319 · 1 parent: 8543882

Added placeholder train_sft and train_rl (both should work)

Files changed (2):
  1. train_rl.py +113 -0
  2. train_sft.py +65 -0
train_rl.py ADDED
@@ -0,0 +1,113 @@
+ import json
+ from datasets import Dataset
+ from unsloth import FastLanguageModel, is_bfloat16_supported
+ from trl import GRPOConfig, GRPOTrainer
+
+ from client import CustomerEnv
+ from models import CustomerAction
+
+ REMOTE_ENV_URL = "https://ramnarayanan747-voice-agent.hf.space"
+
+ MODEL_PATH = "voice_agent_sft"
+ MAX_SEQ_LENGTH = 1024
+
+ def openenv_reward_func(prompts, completions, **kwargs):
+     """
+     The bridge between GRPO and the OpenEnv server.
+     TRL passes in the generated actions; we send them to the remote
+     environment via client.py and return the exact reward it assigns.
+     """
+     rewards = []
+
+     for response in completions:
+         text = response[0]["content"] if isinstance(response, list) else response
+
+         try:
+             action_dict = json.loads(text.strip())
+             action_msg = json.dumps(action_dict)
+
+             with CustomerEnv(base_url=REMOTE_ENV_URL) as env:
+                 env.reset()
+                 result = env.step(CustomerAction(message=action_msg))
+                 rewards.append(float(result.reward))
+
+         except json.JSONDecodeError:
+             # Major penalty if the model forgets its SFT training and outputs bad JSON
+             rewards.append(-5.0)
+         except Exception:
+             # Minor penalty if the action is valid JSON but crashes the environment logic
+             rewards.append(-2.0)
+
+     return rewards
+
+ SYSTEM_PROMPT = "You are a banking Voice Agent. You must output JSON actions using 'speak' or 'tool_call'."
+ intents = [
+     "Customer sees a $215.50 charge from 'TechStore Online'.",
+     "Customer lost their wallet on the subway 10 minutes ago.",
+     "Customer wants to check their checking account balance.",
+ ]
+
+ dataset = Dataset.from_dict({
+     "prompt": [
+         [
+             {"role": "system", "content": SYSTEM_PROMPT},
+             {"role": "user", "content": f"System: Call connected.\nCustomer: {intent}"},
+         ]
+         for intent in intents
+     ]
+ })
+
+ print(f"Loading SFT model from {MODEL_PATH}...")
+ model, tokenizer = FastLanguageModel.from_pretrained(
+     model_name=MODEL_PATH,
+     max_seq_length=MAX_SEQ_LENGTH,
+     load_in_4bit=True,
+     fast_inference=True,
+ )
+
+ # Re-apply LoRA adapters for the RL phase
+ model = FastLanguageModel.get_peft_model(
+     model,
+     r=16,
+     target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
+     use_gradient_checkpointing="unsloth",
+ )
+
+ # Configure the GRPO trainer
+ training_args = GRPOConfig(
+     use_vllm=True,
+     learning_rate=5e-6,  # Keep the RL learning rate much lower than the SFT one
+     adam_beta1=0.9,
+     adam_beta2=0.99,
+     weight_decay=0.1,
+     warmup_ratio=0.1,
+     lr_scheduler_type="cosine",
+     optim="paged_adamw_8bit",
+     logging_steps=1,
+     bf16=is_bfloat16_supported(),
+     fp16=not is_bfloat16_supported(),
+     per_device_train_batch_size=1,
+     gradient_accumulation_steps=4,
+     num_generations=4,  # How many different actions to test per prompt
+     max_prompt_length=256,
+     max_completion_length=256,
+     max_steps=200,
+     output_dir="grpo_outputs",
+ )
+
+ trainer = GRPOTrainer(
+     model=model,
+     reward_funcs=[openenv_reward_func],
+     args=training_args,
+     train_dataset=dataset,
+ )
+
+ print("Starting RL loop over the remote OpenEnv environment...")
+ trainer.train()
+
+ print("Saving final RL-optimized agent...")
+ model.save_pretrained("voice_agent_rl_final")
+ tokenizer.save_pretrained("voice_agent_rl_final")
+ print("Agent successfully trained!")
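A quick way to sanity-check the reward bridge before spending GPU time is to call openenv_reward_func directly on hand-written completions. The sketch below assumes the Space at REMOTE_ENV_URL is reachable, and the {"type": "speak", ...} action shape is an illustrative guess rather than the confirmed schema (the real contract is defined by models.py and the server):

# Hypothetical smoke test; paste below the definition of openenv_reward_func.
# The action fields are assumptions, not the confirmed environment schema.
good = ['{"type": "speak", "text": "How can I help you today?"}']  # well-formed JSON action
bad = ["Let me just answer in plain prose."]                       # not JSON at all

print(openenv_reward_func(prompts=None, completions=good))  # env-assigned reward (makes a network call)
print(openenv_reward_func(prompts=None, completions=bad))   # [-5.0], the bad-JSON penalty path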
train_sft.py ADDED
@@ -0,0 +1,65 @@
+ from unsloth import FastLanguageModel, is_bfloat16_supported
+ from unsloth.chat_templates import get_chat_template, train_on_responses_only
+ from datasets import load_dataset
+ from trl import SFTTrainer
+ from transformers import TrainingArguments
+
+ model, tokenizer = FastLanguageModel.from_pretrained(
+     model_name="unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit",
+     max_seq_length=2048,
+     load_in_4bit=True,
+ )
+ tokenizer = get_chat_template(tokenizer, chat_template="llama-3.1")
+
+ model = FastLanguageModel.get_peft_model(
+     model,
+     r=16,
+     target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
+     lora_alpha=16,
+     lora_dropout=0,
+     bias="none",
+     use_gradient_checkpointing="unsloth",
+ )
+
+ dataset = load_dataset("json", data_files="sft_data.json", split="train")
+
+ def format_prompts(examples):
+     convos = examples["conversations"]
+     texts = [tokenizer.apply_chat_template(c, tokenize=False, add_generation_prompt=False) for c in convos]
+     return {"text": texts}
+
+ dataset = dataset.map(format_prompts, batched=True)
+
+ trainer = SFTTrainer(
+     model=model,
+     tokenizer=tokenizer,
+     train_dataset=dataset,
+     dataset_text_field="text",
+     max_seq_length=2048,
+     args=TrainingArguments(
+         per_device_train_batch_size=2,
+         gradient_accumulation_steps=4,
+         warmup_steps=5,
+         max_steps=150,
+         learning_rate=2e-4,
+         fp16=not is_bfloat16_supported(),  # is_bfloat16_supported is a module-level unsloth helper
+         bf16=is_bfloat16_supported(),
+         logging_steps=10,
+         optim="adamw_8bit",
+         output_dir="sft_outputs",
+         seed=3407,
+     ),
+ )
+
+ # Compute loss only on the assistant turns, not on the prompt tokens
+ trainer = train_on_responses_only(
+     trainer,
+     instruction_part="<|start_header_id|>user<|end_header_id|>\n\n",
+     response_part="<|start_header_id|>assistant<|end_header_id|>\n\n",
+ )
+
+ print("Starting Supervised Fine-Tuning...")
+ trainer.train()
+
+ model.save_pretrained("voice_agent_sft")
+ tokenizer.save_pretrained("voice_agent_sft")
+ print("SFT complete! Base model saved to ./voice_agent_sft")
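For reference, train_sft.py expects sft_data.json to carry a "conversations" column in the role/content chat format that apply_chat_template consumes. A hypothetical generator for one such record follows; the assistant's action fields are illustrative assumptions, not a confirmed schema:

import json

# Hypothetical example record for sft_data.json; message contents are assumptions.
record = {
    "conversations": [
        {"role": "system", "content": "You are a banking Voice Agent. You must output JSON actions using 'speak' or 'tool_call'."},
        {"role": "user", "content": "System: Call connected.\nCustomer: I want to check my checking account balance."},
        {"role": "assistant", "content": json.dumps({"type": "tool_call", "name": "get_balance"})},
    ]
}

# One JSON object per line (JSON Lines) is read cleanly by load_dataset("json", ...).
with open("sft_data.json", "w") as f:
    f.write(json.dumps(record) + "\n")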