Claude committed on
Add clear training progress logging with technical + domain names
- Startup banner: shows all config params (steps, episodes, models,
  estimated total conversations)
- Per-step: [Step/Customer Rep X/Y] or [Step/GRPO Iteration X/Y]
- Per-episode: [Episode/Customer X/Y] with persona details and results
- Both technical names (step, episode, candidate) and domain names
(customer rep, customer, GRPO iteration) shown side-by-side
https://claude.ai/code/session_01DPirJ78YYN4fJUvUFJ5D6V
- layer1/grpo_trainer.py +69 -6
- layer1/train.py +28 -0
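
For illustration, the new format strings would emit log lines shaped like these (the step counts, persona fields, and rewards below are hypothetical values, not output from a real run):

    [Step/GRPO Iteration 2/10][Candidate/Customer Rep 1/4] Evaluating generated prompt (142 chars): You are a customer service rep...
    [Step/GRPO Iteration 2/10][Candidate/Customer Rep 1/4] Episode/Customer 3/5 — persona=7 intent=refund SE=True
    [Step/GRPO Iteration 2/10][Candidate/Customer Rep 1/4] Episode/Customer 3/5 — reward=1.0 correct=True turns=4
    [Step/GRPO Iteration 2/10][Candidate/Customer Rep 1/4] Done — mean_reward=0.8 min=0.0 max=1.0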
layer1/grpo_trainer.py
CHANGED

@@ -100,6 +100,7 @@ class PromptEvaluator:
         system_prompt: str,
         num_episodes: int = 10,
         personas_subset: list[CustomerPersona] | None = None,
+        step_label: str = "",
     ) -> dict[str, Any]:
         """
         Run num_episodes conversations with the given system prompt.
@@ -112,7 +113,13 @@ class PromptEvaluator:
 
         rewards = []
         logs = []
-        for persona in personas_to_use[:num_episodes]:
+        total = min(num_episodes, len(personas_to_use))
+        for ei, persona in enumerate(personas_to_use[:num_episodes]):
+            logger.info(
+                "%s Episode/Customer %d/%d — persona=%d intent=%s SE=%s",
+                step_label, ei + 1, total,
+                persona.id, persona.true_intent, persona.social_engineering,
+            )
             log = self.env.run_episode(
                 system_prompt=system_prompt,
                 agent_fn=self.agent_fn,
@@ -121,9 +128,15 @@ class PromptEvaluator:
             r = reward_fn(log)
             rewards.append(r)
             logs.append(log.to_dict())
+            logger.info(
+                "%s Episode/Customer %d/%d — reward=%.1f correct=%s turns=%d",
+                step_label, ei + 1, total,
+                r, log.intent_correct, log.turns,
+            )
 
+        mean_r = sum(rewards) / len(rewards) if rewards else 0.0
         return {
-            "mean_reward": sum(rewards) / len(rewards) if rewards else 0.0,
+            "mean_reward": mean_r,
             "total_reward": sum(rewards),
             "min_reward": min(rewards) if rewards else 0.0,
             "max_reward": max(rewards) if rewards else 0.0,
@@ -186,18 +199,35 @@ class GRPOPromptTrainer:
     def _reward_function(self, completions, **kwargs):
         """GRPO reward: evaluate each generated system prompt in Layer 2."""
         rewards = []
-        for completion in completions:
+        total_candidates = len(completions)
+        for ci, completion in enumerate(completions):
             if isinstance(completion, list):
                 system_prompt = completion[0].get("content", str(completion))
             else:
                 system_prompt = str(completion)
 
+            step_label = (
+                f"[Step/GRPO Iteration {self._current_step + 1}/{self.config.num_training_steps}]"
+                f"[Candidate/Customer Rep {ci + 1}/{total_candidates}]"
+            )
+            logger.info(
+                "%s Evaluating generated prompt (%d chars): %.80s%s",
+                step_label, len(system_prompt),
+                system_prompt, "..." if len(system_prompt) > 80 else "",
+            )
+
             result = self.evaluator.evaluate_prompt(
                 system_prompt,
                 num_episodes=self.config.episodes_per_candidate,
+                step_label=step_label,
             )
             rewards.append(result["mean_reward"])
-
+
+            logger.info(
+                "%s Done — mean_reward=%.1f min=%.1f max=%.1f",
+                step_label, result["mean_reward"],
+                result["min_reward"], result["max_reward"],
+            )
 
         if self._logger:
             self._logger.log_iteration(
@@ -248,7 +278,19 @@ class GRPOPromptTrainer:
             tokenizer=self._tokenizer,
         )
 
-        logger.info(
+        logger.info(
+            "=== GRPO Training: %d Steps/GRPO Iterations × "
+            "%d Candidates/Customer Rep configs × "
+            "%d Episodes/Customers each ===",
+            self.config.num_training_steps,
+            self.config.num_candidates,
+            self.config.episodes_per_candidate,
+        )
+        logger.info(
+            "Model/Prompt Generator: %s | LoRA r=%d α=%d | LR=%.1e",
+            self.config.model_name, self.config.lora_r,
+            self.config.lora_alpha, self.config.learning_rate,
+        )
         trainer.train()
 
         # Save the trained model
@@ -334,16 +376,37 @@ class MockPromptOptimizer:
     def optimize(self, num_episodes_per_prompt: int = 10) -> dict[str, Any]:
         """Evaluate all candidate prompts and return the best one."""
         self.results = []
+        total_prompts = len(self.CANDIDATE_PROMPTS)
+
+        logger.info(
+            "=== Mock Optimization: %d System Prompts/Customer Rep configs × "
+            "%d Episodes/Customers each ===",
+            total_prompts, num_episodes_per_prompt,
+        )
 
         for i, prompt in enumerate(self.CANDIDATE_PROMPTS):
+            step_label = (
+                f"[Step/Customer Rep {i + 1}/{total_prompts}]"
+            )
+            logger.info(
+                "%s Evaluating system prompt (%d chars): %.80s%s",
+                step_label, len(prompt), prompt, "..." if len(prompt) > 80 else "",
+            )
+
             result = self.evaluator.evaluate_prompt(
                 system_prompt=prompt,
                 num_episodes=num_episodes_per_prompt,
+                step_label=step_label,
            )
             result["prompt"] = prompt
             result["prompt_index"] = i
             self.results.append(result)
-
+
+            logger.info(
+                "%s Done — mean_reward=%.1f min=%.1f max=%.1f",
+                step_label, result["mean_reward"],
+                result["min_reward"], result["max_reward"],
+            )
 
         if self._logger:
             self._logger.log_iteration(step=i, prompt=prompt, eval_result=result)
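
Note: the new logger.info calls use lazy %-style formatting and assume a module-level logger already defined in layer1/grpo_trainer.py. A minimal sketch of the setup these calls rely on (the basicConfig format string here is an assumption, not taken from the repo):

    import logging

    # Module-level logger that the diff's logger.info calls assume exists.
    logger = logging.getLogger(__name__)

    # Hypothetical handler config; the project's actual format may differ.
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")

    # %-style args are interpolated only when INFO is enabled, so disabled
    # log levels skip the string formatting entirely.
    logger.info("%s Episode/Customer %d/%d — reward=%.1f", "[Step/GRPO Iteration 1/3]", 2, 5, 1.0)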
layer1/train.py
CHANGED

@@ -64,8 +64,35 @@ def load_evaluator(hf_token: str | None = None) -> PromptEvaluator:
     return PromptEvaluator(personas=personas, simulator=simulator, agent_fn=agent)
 
 
+def _print_config_banner(mode: str, args):
+    """Print training configuration with both technical and domain names."""
+    print(f"\n{'='*70}")
+    print(f" TRAINING CONFIGURATION")
+    print(f"{'='*70}")
+    print(f" Mode: {mode}")
+    if mode == "mock":
+        n_prompts = len(MockPromptOptimizer.CANDIDATE_PROMPTS)
+        print(f" Steps / System Prompts: {n_prompts} (hand-written)")
+    else:
+        print(f" Steps / GRPO Iterations: {args.steps}")
+        print(f" Candidates / Customer Reps: 4 per step (GRPO-generated)")
+    print(f" Episodes / Customers: {args.episodes} per prompt")
+    print(f" Customer Rep Agent: Llama 3.1 8B (HF Inference API)")
+    print(f" Customer Simulator: Llama 3.1 8B (HF Inference API)")
+    print(f" Total LLM conversations: ~{_estimate_conversations(mode, args)}")
+    print(f" Report generation: {'yes' if args.report else 'no'}")
+    print(f"{'='*70}\n")
+
+
+def _estimate_conversations(mode: str, args) -> int:
+    if mode == "mock":
+        return len(MockPromptOptimizer.CANDIDATE_PROMPTS) * args.episodes
+    return args.steps * 4 * args.episodes  # steps × candidates × episodes
+
+
 def run_mock(args):
     """Run mock optimization with hand-written prompts."""
+    _print_config_banner("mock", args)
     evaluator = load_evaluator(args.hf_token)
     training_logger = TrainingLogger(
         log_dir=args.log_dir,
@@ -102,6 +129,7 @@ def run_mock(args):
 
 def run_train(args):
     """Run full GRPO training (requires GPU)."""
+    _print_config_banner("train", args)
     evaluator = load_evaluator(args.hf_token)
     training_logger = TrainingLogger(log_dir=args.log_dir, total_steps=args.steps)
     config = GRPOConfig(
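
The estimate in _estimate_conversations is plain arithmetic: train mode multiplies steps × 4 candidates × episodes, so hypothetical flags --steps 10 --episodes 5 give 10 × 4 × 5 = 200 conversations, while mock mode multiplies the number of hand-written prompts by episodes. With those example flags (and --report), the banner would print roughly:

    ======================================================================
     TRAINING CONFIGURATION
    ======================================================================
     Mode: train
     Steps / GRPO Iterations: 10
     Candidates / Customer Reps: 4 per step (GRPO-generated)
     Episodes / Customers: 5 per prompt
     Customer Rep Agent: Llama 3.1 8B (HF Inference API)
     Customer Simulator: Llama 3.1 8B (HF Inference API)
     Total LLM conversations: ~200
     Report generation: yes
    ======================================================================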
|