Claude committed on
Commit
4ac72af
·
unverified ·
1 Parent(s): b259333

Wire up real LLM integration via HF Inference API

Browse files

- Customer simulator and agent now use Llama 3.1 8B via HF Inference API
when HF_TOKEN is set in .env (gitignored, never pushed)
- Graceful fallback: if API credits deplete (402), auto-falls back to
rule-based simulation for remaining calls
- HFAgent uses Llama 3.1 (not Qwen which isn't available on free tier)
- A/B test supports --mode llm|rule flag, shows sample conversations
- Layer 1 train.py supports --llm-agent flag for real LLM evaluation
- Added python-dotenv + datasets to dependencies
- All .env loading via dotenv, keys never touch git
- 31 tests passing

https://claude.ai/code/session_01DPirJ78YYN4fJUvUFJ5D6V

layer1/train.py CHANGED
@@ -20,6 +20,10 @@ import logging
20
  import sys
21
  import os
22
 
 
 
 
 
23
  sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
24
 
25
  from layer1.grpo_trainer import (
@@ -30,23 +34,35 @@ from layer1.grpo_trainer import (
30
  build_meta_prompt,
31
  )
32
  from layer2.customer_sim import CustomerPersona, CustomerSimulator
 
33
  from personas.generate_personas import generate_personas
34
 
35
  logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(message)s")
36
  logger = logging.getLogger(__name__)
37
 
38
 
39
- def load_evaluator(hf_token: str | None = None) -> PromptEvaluator:
40
- """Load personas and create the evaluator."""
 
41
  personas_data = generate_personas(100)
42
  personas = [CustomerPersona(**p) for p in personas_data]
43
- simulator = CustomerSimulator(hf_token=hf_token)
44
- return PromptEvaluator(personas=personas, simulator=simulator)
 
 
 
 
 
 
 
 
 
 
45
 
46
 
47
  def run_mock(args):
48
  """Run mock optimization with hand-written prompts."""
49
- evaluator = load_evaluator(args.hf_token)
50
  optimizer = MockPromptOptimizer(evaluator)
51
  result = optimizer.optimize(num_episodes_per_prompt=args.episodes)
52
 
@@ -66,7 +82,7 @@ def run_mock(args):
66
 
67
  def run_train(args):
68
  """Run full GRPO training (requires GPU)."""
69
- evaluator = load_evaluator(args.hf_token)
70
  config = GRPOConfig(
71
  num_training_steps=args.steps,
72
  episodes_per_candidate=args.episodes,
@@ -89,7 +105,7 @@ def run_train(args):
89
 
90
  def run_eval(args):
91
  """Evaluate a single prompt."""
92
- evaluator = load_evaluator(args.hf_token)
93
  result = evaluator.evaluate_prompt(args.prompt, num_episodes=args.episodes)
94
  print(f"Prompt: {args.prompt[:80]}...")
95
  print(f"Mean reward: {result['mean_reward']:.1f}")
@@ -118,6 +134,8 @@ def main():
118
  parser.add_argument("--output-dir", type=str, default="./grpo_output", help="Training output dir")
119
  parser.add_argument("--hf-token", type=str, default=None, help="HuggingFace API token")
120
  parser.add_argument("--prompt", type=str, default=None, help="Prompt to evaluate (eval mode)")
 
 
121
  args = parser.parse_args()
122
 
123
  if args.mode == "train":
 
20
  import sys
21
  import os
22
 
23
+ # Auto-load .env for HF_TOKEN
24
+ from dotenv import load_dotenv
25
+ load_dotenv(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), ".env"))
26
+
27
  sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
28
 
29
  from layer1.grpo_trainer import (
 
34
  build_meta_prompt,
35
  )
36
  from layer2.customer_sim import CustomerPersona, CustomerSimulator
37
+ from layer2.hf_agent import HFAgent
38
  from personas.generate_personas import generate_personas
39
 
40
  logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(message)s")
41
  logger = logging.getLogger(__name__)
42
 
43
 
44
+ def load_evaluator(hf_token: str | None = None, use_llm_agent: bool = False) -> PromptEvaluator:
45
+ """Load personas and create the evaluator with optional LLM agent."""
46
+ token = hf_token or os.environ.get("HF_TOKEN")
47
  personas_data = generate_personas(100)
48
  personas = [CustomerPersona(**p) for p in personas_data]
49
+ simulator = CustomerSimulator(hf_token=token)
50
+
51
+ agent_fn = None
52
+ if use_llm_agent and token:
53
+ agent = HFAgent(hf_token=token)
54
+ if agent.is_llm_available:
55
+ agent_fn = agent
56
+ logger.info("Using LLM agent (Llama 3.1 8B)")
57
+ else:
58
+ logger.warning("LLM agent not available, using rule-based fallback")
59
+
60
+ return PromptEvaluator(personas=personas, simulator=simulator, agent_fn=agent_fn)
61
 
62
 
63
  def run_mock(args):
64
  """Run mock optimization with hand-written prompts."""
65
+ evaluator = load_evaluator(args.hf_token, use_llm_agent=args.llm_agent)
66
  optimizer = MockPromptOptimizer(evaluator)
67
  result = optimizer.optimize(num_episodes_per_prompt=args.episodes)
68
 
 
82
 
83
  def run_train(args):
84
  """Run full GRPO training (requires GPU)."""
85
+ evaluator = load_evaluator(args.hf_token, use_llm_agent=args.llm_agent)
86
  config = GRPOConfig(
87
  num_training_steps=args.steps,
88
  episodes_per_candidate=args.episodes,
 
105
 
106
  def run_eval(args):
107
  """Evaluate a single prompt."""
108
+ evaluator = load_evaluator(args.hf_token, use_llm_agent=args.llm_agent)
109
  result = evaluator.evaluate_prompt(args.prompt, num_episodes=args.episodes)
110
  print(f"Prompt: {args.prompt[:80]}...")
111
  print(f"Mean reward: {result['mean_reward']:.1f}")
 
134
  parser.add_argument("--output-dir", type=str, default="./grpo_output", help="Training output dir")
135
  parser.add_argument("--hf-token", type=str, default=None, help="HuggingFace API token")
136
  parser.add_argument("--prompt", type=str, default=None, help="Prompt to evaluate (eval mode)")
137
+ parser.add_argument("--llm-agent", action="store_true",
138
+ help="Use LLM (Llama 3.1) as the agent instead of rule-based")
139
  args = parser.parse_args()
140
 
141
  if args.mode == "train":
layer2/customer_sim.py CHANGED
@@ -80,7 +80,18 @@ class CustomerSimulator:
80
  ) -> str:
81
  """Generate the next customer reply given the conversation so far."""
82
  if self._client is not None:
83
- return self._generate_llm_reply(persona, conversation_history, agent_message)
 
 
 
 
 
 
 
 
 
 
 
84
  return self._generate_rule_reply(persona, conversation_history, agent_message)
85
 
86
  def _generate_llm_reply(
 
80
  ) -> str:
81
  """Generate the next customer reply given the conversation so far."""
82
  if self._client is not None:
83
+ try:
84
+ return self._generate_llm_reply(persona, conversation_history, agent_message)
85
+ except Exception as e:
86
+ if "402" in str(e) or "Payment Required" in str(e):
87
+ import logging
88
+ logging.getLogger(__name__).warning(
89
+ "HF API credits depleted, falling back to rule-based. "
90
+ "Get more credits at https://huggingface.co/settings/billing"
91
+ )
92
+ self._client = None # disable for remaining calls
93
+ else:
94
+ raise
95
  return self._generate_rule_reply(persona, conversation_history, agent_message)
96
 
97
  def _generate_llm_reply(
layer2/hf_agent.py CHANGED
@@ -1,8 +1,9 @@
1
  """
2
  HF Inference API wrapper for the voice agent (Layer 2).
3
 
4
- Uses a small model via HF Inference to act as the customer support agent
5
- during evaluation. In training (Layer 1), the agent is the model being optimized.
 
6
  """
7
 
8
  from __future__ import annotations
@@ -21,11 +22,11 @@ class HFAgent:
21
  """
22
  Voice agent powered by HF Inference API.
23
 
24
- This wraps a small model (e.g. Qwen 2.5 3B) with a system prompt
25
- from Layer 1, and generates responses in the customer support conversation.
26
  """
27
 
28
- DEFAULT_MODEL = "Qwen/Qwen2.5-3B-Instruct"
29
 
30
  def __init__(self, model_id: str | None = None, hf_token: str | None = None):
31
  self.model_id = model_id or self.DEFAULT_MODEL
@@ -34,6 +35,10 @@ class HFAgent:
34
  if self.hf_token and InferenceClient is not None:
35
  self._client = InferenceClient(token=self.hf_token)
36
 
 
 
 
 
37
  def __call__(
38
  self,
39
  system_prompt: str,
@@ -46,7 +51,7 @@ class HFAgent:
46
  Compatible with ConversationEnvironment.run_episode(agent_fn=...).
47
  """
48
  if self._client is None:
49
- return self._fallback_response(observation)
50
 
51
  messages = [{"role": "system", "content": system_prompt}]
52
 
@@ -61,15 +66,26 @@ class HFAgent:
61
  if customer_msg:
62
  messages.append({"role": "user", "content": customer_msg})
63
 
64
- response = self._client.chat_completion(
65
- model=self.model_id,
66
- messages=messages,
67
- max_tokens=300,
68
- temperature=0.3,
69
- )
70
- return response.choices[0].message.content.strip()
71
-
72
- def _fallback_response(self, observation: dict[str, Any]) -> str:
 
 
 
 
 
 
 
 
 
 
 
73
  """Rule-based fallback when no HF token is available."""
74
  customer_msg = observation.get("customer_message", "").lower()
75
  intents = observation.get("intents", [])
 
1
  """
2
  HF Inference API wrapper for the voice agent (Layer 2).
3
 
4
+ Uses Llama 3.1 8B Instruct via HF Inference to act as the customer support
5
+ agent during evaluation. In training (Layer 1), the agent is the model being
6
+ optimized — this module provides the inference-time agent for A/B testing.
7
  """
8
 
9
  from __future__ import annotations
 
22
  """
23
  Voice agent powered by HF Inference API.
24
 
25
+ Takes a system prompt from Layer 1 and generates responses
26
+ in the customer support conversation using Llama 3.1 8B.
27
  """
28
 
29
+ DEFAULT_MODEL = "meta-llama/Llama-3.1-8B-Instruct"
30
 
31
  def __init__(self, model_id: str | None = None, hf_token: str | None = None):
32
  self.model_id = model_id or self.DEFAULT_MODEL
 
35
  if self.hf_token and InferenceClient is not None:
36
  self._client = InferenceClient(token=self.hf_token)
37
 
38
+ @property
39
+ def is_llm_available(self) -> bool:
40
+ return self._client is not None
41
+
42
  def __call__(
43
  self,
44
  system_prompt: str,
 
51
  Compatible with ConversationEnvironment.run_episode(agent_fn=...).
52
  """
53
  if self._client is None:
54
+ return self._fallback_response(system_prompt, observation)
55
 
56
  messages = [{"role": "system", "content": system_prompt}]
57
 
 
66
  if customer_msg:
67
  messages.append({"role": "user", "content": customer_msg})
68
 
69
+ try:
70
+ response = self._client.chat_completion(
71
+ model=self.model_id,
72
+ messages=messages,
73
+ max_tokens=300,
74
+ temperature=0.3,
75
+ )
76
+ return response.choices[0].message.content.strip()
77
+ except Exception as e:
78
+ if "402" in str(e) or "Payment Required" in str(e):
79
+ import logging
80
+ logging.getLogger(__name__).warning(
81
+ "HF API credits depleted, falling back to rule-based. "
82
+ "Get more credits at https://huggingface.co/settings/billing"
83
+ )
84
+ self._client = None
85
+ return self._fallback_response(system_prompt, observation)
86
+ raise
87
+
88
+ def _fallback_response(self, system_prompt: str, observation: dict[str, Any]) -> str:
89
  """Rule-based fallback when no HF token is available."""
90
  customer_msg = observation.get("customer_message", "").lower()
91
  intents = observation.get("intents", [])
pyproject.toml CHANGED
@@ -12,6 +12,7 @@ dependencies = [
12
  "huggingface-hub>=0.20.0",
13
  "requests>=2.31.0",
14
  "pydantic>=2.0",
 
15
  "gradio>=4.0.0",
16
  ]
17
 
@@ -24,6 +25,7 @@ train = [
24
  "peft>=0.9.0",
25
  "bitsandbytes>=0.43.0",
26
  "accelerate>=0.27.0",
 
27
  ]
28
  dev = [
29
  "pytest>=8.0",
 
12
  "huggingface-hub>=0.20.0",
13
  "requests>=2.31.0",
14
  "pydantic>=2.0",
15
+ "python-dotenv>=1.0.0",
16
  "gradio>=4.0.0",
17
  ]
18
 
 
25
  "peft>=0.9.0",
26
  "bitsandbytes>=0.43.0",
27
  "accelerate>=0.27.0",
28
+ "datasets>=2.18.0",
29
  ]
30
  dev = [
31
  "pytest>=8.0",
scripts/ab_test.py CHANGED
@@ -1,8 +1,11 @@
1
  """
2
- A/B Test: Compare base prompt vs trained/optimized prompt on 100 simulated customers.
 
 
 
3
 
4
  Usage:
5
- python -m scripts.ab_test [--episodes 100] [--hf-token TOKEN]
6
  """
7
 
8
  from __future__ import annotations
@@ -12,11 +15,16 @@ import json
12
  import sys
13
  import os
14
 
 
 
 
 
15
  sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
16
 
17
  from layer0.reward import reward_fn, BANKING_INTENTS
18
  from layer2.customer_sim import CustomerPersona, CustomerSimulator
19
  from layer2.environment import ConversationEnvironment, EnvConfig
 
20
  from personas.generate_personas import generate_personas
21
 
22
 
@@ -42,16 +50,34 @@ TRAINED_PROMPT = (
42
 
43
 
44
  def run_ab_test(
45
- num_episodes: int = 100,
46
  hf_token: str | None = None,
 
47
  ) -> dict:
48
- """Run A/B test comparing base vs trained prompt."""
 
 
 
 
 
 
 
 
 
49
  # Load personas
50
  personas_data = generate_personas(num_episodes)
51
  personas = [CustomerPersona(**p) for p in personas_data]
52
 
53
- # Initialize simulator
54
- simulator = CustomerSimulator(hf_token=hf_token)
 
 
 
 
 
 
 
 
55
 
56
  # Create environment
57
  env = ConversationEnvironment(
@@ -73,9 +99,17 @@ def run_ab_test(
73
  correct = 0
74
  injection_resisted = 0
75
  injection_total = 0
 
76
 
77
  for i, persona in enumerate(personas):
78
- log = env.run_episode(system_prompt=prompt, persona=persona)
 
 
 
 
 
 
 
79
  r = reward_fn(log)
80
  rewards.append(r)
81
  turns_list.append(log.turns)
@@ -88,7 +122,20 @@ def run_ab_test(
88
  if not log.injection_succeeded:
89
  injection_resisted += 1
90
 
91
- if (i + 1) % 25 == 0:
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  print(f" [{i+1}/{num_episodes}] avg_reward={sum(rewards)/len(rewards):.1f}")
93
 
94
  results[label] = {
@@ -101,6 +148,8 @@ def run_ab_test(
101
  "min_reward": min(rewards),
102
  "max_reward": max(rewards),
103
  "total_episodes": num_episodes,
 
 
104
  }
105
 
106
  return results
@@ -112,6 +161,10 @@ def print_results(results: dict):
112
  print("=" * 62)
113
  print(f"{'A/B TEST RESULTS':^62}")
114
  print("=" * 62)
 
 
 
 
115
  print(f"{'Metric':<25} {'Base Prompt':>15} {'Trained Prompt':>18}")
116
  print("-" * 62)
117
 
@@ -129,27 +182,49 @@ def print_results(results: dict):
129
  print(f"{name:<25} {b_val:>15} {t_val:>18}")
130
 
131
  print("=" * 62)
132
- print()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
 
135
  def main():
136
  parser = argparse.ArgumentParser(description="A/B test: base vs trained prompt")
137
- parser.add_argument("--episodes", type=int, default=100, help="Number of episodes per prompt")
138
  parser.add_argument("--hf-token", type=str, default=None, help="HuggingFace API token")
 
 
139
  parser.add_argument("--output", type=str, default=None, help="Save results to JSON file")
140
  args = parser.parse_args()
141
 
142
  results = run_ab_test(
143
  num_episodes=args.episodes,
144
  hf_token=args.hf_token,
 
145
  )
146
 
147
  print_results(results)
148
 
149
  if args.output:
 
 
 
150
  with open(args.output, "w") as f:
151
  json.dump(results, f, indent=2)
152
- print(f"Results saved to {args.output}")
153
 
154
 
155
  if __name__ == "__main__":
 
1
  """
2
+ A/B Test: Compare base prompt vs trained/optimized prompt.
3
+
4
+ Uses real LLM (Llama 3.1 8B via HF Inference API) for both
5
+ the customer simulator and the voice agent when HF_TOKEN is set.
6
 
7
  Usage:
8
+ python -m scripts.ab_test [--episodes 10] [--mode llm|rule]
9
  """
10
 
11
  from __future__ import annotations
 
15
  import sys
16
  import os
17
 
18
+ # Auto-load .env
19
+ from dotenv import load_dotenv
20
+ load_dotenv(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), ".env"))
21
+
22
  sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
23
 
24
  from layer0.reward import reward_fn, BANKING_INTENTS
25
  from layer2.customer_sim import CustomerPersona, CustomerSimulator
26
  from layer2.environment import ConversationEnvironment, EnvConfig
27
+ from layer2.hf_agent import HFAgent
28
  from personas.generate_personas import generate_personas
29
 
30
 
 
50
 
51
 
52
  def run_ab_test(
53
+ num_episodes: int = 10,
54
  hf_token: str | None = None,
55
+ mode: str = "llm",
56
  ) -> dict:
57
+ """
58
+ Run A/B test comparing base vs trained prompt.
59
+
60
+ Args:
61
+ num_episodes: Number of episodes per prompt
62
+ hf_token: HuggingFace API token (auto-loaded from .env if not provided)
63
+ mode: "llm" for real LLM agent+customer, "rule" for rule-based fallback
64
+ """
65
+ token = hf_token or os.environ.get("HF_TOKEN")
66
+
67
  # Load personas
68
  personas_data = generate_personas(num_episodes)
69
  personas = [CustomerPersona(**p) for p in personas_data]
70
 
71
+ # Initialize simulator (uses LLM if token available)
72
+ simulator = CustomerSimulator(hf_token=token if mode == "llm" else None)
73
+
74
+ # Initialize LLM agent (uses LLM if token available)
75
+ agent = HFAgent(hf_token=token if mode == "llm" else None)
76
+
77
+ using_llm = mode == "llm" and agent.is_llm_available
78
+ print(f"Mode: {'LLM (Llama 3.1 8B)' if using_llm else 'Rule-based'}")
79
+ print(f"Customer sim: {'LLM' if simulator._client else 'Rule-based'}")
80
+ print(f"Agent: {'LLM' if agent.is_llm_available else 'Rule-based'}")
81
 
82
  # Create environment
83
  env = ConversationEnvironment(
 
99
  correct = 0
100
  injection_resisted = 0
101
  injection_total = 0
102
+ sample_conversations = []
103
 
104
  for i, persona in enumerate(personas):
105
+ # Use LLM agent if available, otherwise default rule-based
106
+ agent_fn = agent if using_llm else None
107
+
108
+ log = env.run_episode(
109
+ system_prompt=prompt,
110
+ agent_fn=agent_fn,
111
+ persona=persona,
112
+ )
113
  r = reward_fn(log)
114
  rewards.append(r)
115
  turns_list.append(log.turns)
 
122
  if not log.injection_succeeded:
123
  injection_resisted += 1
124
 
125
+ # Save first 3 conversations for inspection
126
+ if len(sample_conversations) < 3:
127
+ sample_conversations.append({
128
+ "persona_id": persona.id,
129
+ "true_intent": persona.true_intent,
130
+ "social_engineering": persona.social_engineering,
131
+ "messages": log.messages if hasattr(log, "messages") else [],
132
+ "reward": r,
133
+ "intent_correct": log.intent_correct,
134
+ "injection_succeeded": log.injection_succeeded,
135
+ "turns": log.turns,
136
+ })
137
+
138
+ if (i + 1) % max(1, num_episodes // 4) == 0:
139
  print(f" [{i+1}/{num_episodes}] avg_reward={sum(rewards)/len(rewards):.1f}")
140
 
141
  results[label] = {
 
148
  "min_reward": min(rewards),
149
  "max_reward": max(rewards),
150
  "total_episodes": num_episodes,
151
+ "mode": "llm" if using_llm else "rule",
152
+ "sample_conversations": sample_conversations,
153
  }
154
 
155
  return results
 
161
  print("=" * 62)
162
  print(f"{'A/B TEST RESULTS':^62}")
163
  print("=" * 62)
164
+
165
+ mode = results.get("base", {}).get("mode", "unknown")
166
+ print(f"{'Mode: ' + mode:^62}")
167
+ print("-" * 62)
168
  print(f"{'Metric':<25} {'Base Prompt':>15} {'Trained Prompt':>18}")
169
  print("-" * 62)
170
 
 
182
  print(f"{name:<25} {b_val:>15} {t_val:>18}")
183
 
184
  print("=" * 62)
185
+
186
+ # Print sample conversations
187
+ for label in ["base", "trained"]:
188
+ samples = results[label].get("sample_conversations", [])
189
+ if samples:
190
+ print(f"\n--- Sample conversations ({label.upper()}) ---")
191
+ for conv in samples[:2]:
192
+ print(f" Persona {conv['persona_id']} ({conv['true_intent']}, "
193
+ f"SE={conv['social_engineering']})")
194
+ for msg in conv.get("messages", []):
195
+ if isinstance(msg, dict):
196
+ role = "Customer" if msg.get("role") == "customer" else "Agent"
197
+ text = msg.get("content", "")[:120]
198
+ print(f" [{role}] {text}")
199
+ print(f" => reward={conv['reward']:.1f} correct={conv['intent_correct']} "
200
+ f"injection={conv['injection_succeeded']}")
201
+ print()
202
 
203
 
204
  def main():
205
  parser = argparse.ArgumentParser(description="A/B test: base vs trained prompt")
206
+ parser.add_argument("--episodes", type=int, default=10, help="Number of episodes per prompt")
207
  parser.add_argument("--hf-token", type=str, default=None, help="HuggingFace API token")
208
+ parser.add_argument("--mode", choices=["llm", "rule"], default="llm",
209
+ help="llm=real LLM agent+customer, rule=rule-based fallback")
210
  parser.add_argument("--output", type=str, default=None, help="Save results to JSON file")
211
  args = parser.parse_args()
212
 
213
  results = run_ab_test(
214
  num_episodes=args.episodes,
215
  hf_token=args.hf_token,
216
+ mode=args.mode,
217
  )
218
 
219
  print_results(results)
220
 
221
  if args.output:
222
+ # Remove non-serializable data
223
+ for label in results:
224
+ results[label].pop("sample_conversations", None)
225
  with open(args.output, "w") as f:
226
  json.dump(results, f, indent=2)
227
+ print(f"\nResults saved to {args.output}")
228
 
229
 
230
  if __name__ == "__main__":