AaryanK commited on
Commit
e438405
·
verified ·
1 Parent(s): 246112a

Upload grpo_run_nocot.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. grpo_run_nocot.py +206 -0
grpo_run_nocot.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ GRPO Fine-Tune WITHOUT Chain-of-Thought
4
+ Trains Arch-Router-1.5B to output just {"route": "..."} with better accuracy,
5
+ no reasoning overhead.
6
+ """
7
+
8
+ from unsloth import FastLanguageModel, is_bfloat16_supported
9
+ import torch
10
+ import re
11
+ import json
12
+ from datasets import Dataset
13
+ from collections import Counter
14
+
15
+ # ── Model loading ──
16
+ max_seq_length = 512
17
+ lora_rank = 32
18
+
19
+ model, tokenizer = FastLanguageModel.from_pretrained(
20
+ model_name="katanemo/Arch-Router-1.5B",
21
+ max_seq_length=max_seq_length,
22
+ load_in_4bit=True,
23
+ fast_inference=True,
24
+ max_lora_rank=lora_rank,
25
+ gpu_memory_utilization=0.6,
26
+ )
27
+
28
+ model = FastLanguageModel.get_peft_model(
29
+ model,
30
+ r=lora_rank,
31
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
32
+ lora_alpha=lora_rank,
33
+ use_gradient_checkpointing="unsloth",
34
+ random_state=3407,
35
+ )
36
+
37
+ # ── Route policies ──
38
+ ROUTE_POLICIES = [
39
+ {"name": "simple", "description": "Simple factual questions, greetings, basic lookups, yes/no answers, FAQ-style queries, single-step tasks, status checks, straightforward requests"},
40
+ {"name": "medium", "description": "Multi-step reasoning, summarization of moderate-length text, data extraction, moderate analysis, comparison tasks, troubleshooting, explanations requiring some depth"},
41
+ {"name": "complex", "description": "Complex multi-document reasoning, deep analysis, legal or financial interpretation, creative writing, code generation, multi-constraint problem solving, liability assessment, comprehensive evaluation"},
42
+ ]
43
+
44
+ # System prompt - NO chain of thought, just direct JSON output
45
+ SYSTEM_PROMPT = f"""You are a routing assistant. Given the route policies and user message, select the best matching route.
46
+
47
+ <route_policies>
48
+ {json.dumps(ROUTE_POLICIES)}
49
+ </route_policies>
50
+
51
+ Select the best route for this user message. Respond with ONLY valid JSON: {{"route": "route_name"}}"""
52
+
53
+
54
+ def extract_route(text: str) -> str | None:
55
+ try:
56
+ parsed = json.loads(text.strip())
57
+ route = parsed.get("route")
58
+ if route in ("simple", "medium", "complex"):
59
+ return route
60
+ except (json.JSONDecodeError, TypeError):
61
+ pass
62
+ for tier in ("simple", "medium", "complex"):
63
+ if tier in text.lower():
64
+ return tier
65
+ return None
66
+
67
+
68
+ # ── Load training data ──
69
+ import os
70
+ DATA_PATHS = ["scripts/grpo_training_data.json", "grpo_training_data.json", "/content/grpo_training_data.json"]
71
+ data_path = next((p for p in DATA_PATHS if os.path.exists(p)), None)
72
+ if data_path is None:
73
+ raise FileNotFoundError("Training data not found")
74
+
75
+ with open(data_path) as f:
76
+ raw_data = json.load(f)
77
+
78
+ print(f"Loaded {len(raw_data)} training examples")
79
+
80
+ formatted = []
81
+ for item in raw_data:
82
+ formatted.append({
83
+ "prompt": [
84
+ {"role": "system", "content": SYSTEM_PROMPT},
85
+ {"role": "user", "content": item["prompt"]},
86
+ ],
87
+ "answer": item["expected_route"],
88
+ })
89
+
90
+ dataset = Dataset.from_list(formatted)
91
+ print(f"Route distribution: {dict(Counter(item['expected_route'] for item in raw_data))}")
92
+
93
+
94
+ # ── Reward functions (no XML/format rewards - just correctness + valid JSON) ──
95
+ def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
96
+ responses = [completion[0]["content"] for completion in completions]
97
+ extracted = [extract_route(r) for r in responses]
98
+ q = prompts[0][-1]["content"]
99
+ print(f"--- Q: {q[:60]} | Expected: {answer[0]} | Got: {extracted[0]} | Raw: {responses[0][:80]}")
100
+ return [2.0 if r == a else 0.0 for r, a in zip(extracted, answer)]
101
+
102
+
103
+ def valid_route_reward_func(completions, **kwargs) -> list[float]:
104
+ responses = [completion[0]["content"] for completion in completions]
105
+ extracted = [extract_route(r) for r in responses]
106
+ return [0.5 if r in ("simple", "medium", "complex") else 0.0 for r in extracted]
107
+
108
+
109
+ def json_format_reward_func(completions, **kwargs) -> list[float]:
110
+ responses = [completion[0]["content"] for completion in completions]
111
+ rewards = []
112
+ for r in responses:
113
+ try:
114
+ parsed = json.loads(r.strip())
115
+ if "route" in parsed:
116
+ rewards.append(1.0) # Higher reward for clean JSON
117
+ else:
118
+ rewards.append(0.2)
119
+ except (json.JSONDecodeError, TypeError):
120
+ rewards.append(0.0)
121
+ return rewards
122
+
123
+
124
+ def brevity_reward_func(completions, **kwargs) -> list[float]:
125
+ """Reward shorter outputs β€” we want just the JSON, nothing else."""
126
+ responses = [completion[0]["content"] for completion in completions]
127
+ rewards = []
128
+ for r in responses:
129
+ length = len(r.strip())
130
+ if length <= 25: # {"route": "complex"} is 21 chars
131
+ rewards.append(0.5)
132
+ elif length <= 50:
133
+ rewards.append(0.2)
134
+ else:
135
+ rewards.append(0.0)
136
+ return rewards
137
+
138
+
139
+ # ── Training ──
140
+ from trl import GRPOConfig, GRPOTrainer
141
+
142
+ training_args = GRPOConfig(
143
+ use_vllm=True,
144
+ learning_rate=5e-6,
145
+ adam_beta1=0.9,
146
+ adam_beta2=0.99,
147
+ weight_decay=0.1,
148
+ warmup_ratio=0.1,
149
+ lr_scheduler_type="cosine",
150
+ optim="adamw_8bit",
151
+ logging_steps=1,
152
+ per_device_train_batch_size=1,
153
+ gradient_accumulation_steps=1,
154
+ num_generations=4,
155
+ max_prompt_length=384,
156
+ max_completion_length=64, # Much shorter - just need JSON output
157
+ max_steps=150,
158
+ save_steps=150,
159
+ max_grad_norm=0.1,
160
+ report_to="none",
161
+ output_dir="outputs_modelgate_nocot",
162
+ )
163
+
164
+ trainer = GRPOTrainer(
165
+ model=model,
166
+ processing_class=tokenizer,
167
+ reward_funcs=[
168
+ json_format_reward_func,
169
+ valid_route_reward_func,
170
+ brevity_reward_func,
171
+ correctness_reward_func,
172
+ ],
173
+ args=training_args,
174
+ train_dataset=dataset,
175
+ )
176
+ trainer.train()
177
+
178
+ # ── Save ──
179
+ model.save_pretrained("modelgate_arch_router_nocot_lora")
180
+ tokenizer.save_pretrained("modelgate_arch_router_nocot_lora")
181
+ print("\nLoRA adapter saved to modelgate_arch_router_nocot_lora/")
182
+
183
+ # ── Quick test ──
184
+ from vllm import SamplingParams
185
+
186
+ model.save_lora("modelgate_nocot_test_lora")
187
+ sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=30)
188
+
189
+ test_prompts = [
190
+ ("What is your return policy?", "simple"),
191
+ ("Compare the settlement amounts for similar property damage claims in the Southeast region this quarter.", "medium"),
192
+ ("Analyze the multi-party liability exposure across claims #8901, #8902, and #8903 from the warehouse incident.", "complex"),
193
+ ]
194
+
195
+ for prompt_text, expected in test_prompts:
196
+ text = tokenizer.apply_chat_template([
197
+ {"role": "system", "content": SYSTEM_PROMPT},
198
+ {"role": "user", "content": prompt_text},
199
+ ], tokenize=False, add_generation_prompt=True)
200
+ output = model.fast_generate(
201
+ [text], sampling_params=sampling_params,
202
+ lora_request=model.load_lora("modelgate_nocot_test_lora"),
203
+ )[0].outputs[0].text
204
+ route = extract_route(output)
205
+ status = "βœ“" if route == expected else "βœ—"
206
+ print(f"{status} Expected: {expected:>7s} | Got: {str(route):>7s} | Raw: {output[:60]}")