File size: 16,229 Bytes
b259333
288d9a2
b259333
4e2b74e
 
 
b259333
4e2b74e
 
 
 
 
b259333
 
 
 
 
 
 
 
 
 
 
 
f703ff1
b259333
4ac72af
 
 
 
b259333
 
28bcb40
c2dc160
506d641
76f180f
b259333
4ac72af
b259333
 
 
 
 
 
f703ff1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10418d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3dc48b7
 
 
 
 
21da591
4ac72af
3dc48b7
 
 
10418d0
 
 
 
 
 
 
 
 
3dc48b7
 
b259333
3dc48b7
 
 
 
10418d0
3dc48b7
4ac72af
3dc48b7
 
 
 
10418d0
3dc48b7
21da591
 
 
 
10418d0
 
 
4ac72af
21da591
b259333
 
10418d0
4e2b74e
 
 
 
 
4b89b89
4e2b74e
4b89b89
4e2b74e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10418d0
 
 
4e2b74e
 
 
 
 
 
4b89b89
 
 
28bcb40
288d9a2
10418d0
f703ff1
 
 
 
 
 
 
3dc48b7
4e2b74e
 
b259333
76f180f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
506d641
b259333
c2dc160
 
 
 
 
 
 
 
 
 
 
b259333
 
 
 
 
 
 
 
 
4e2b74e
 
 
b259333
 
71b0977
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28bcb40
4e2b74e
506d641
 
 
 
 
4e2b74e
 
 
506d641
 
 
f703ff1
 
 
 
 
 
 
 
 
 
 
 
76f180f
 
28bcb40
76f180f
28bcb40
76f180f
 
 
28bcb40
 
76f180f
 
 
28bcb40
 
b259333
4e2b74e
b259333
4e2b74e
 
 
b259333
 
 
 
 
 
 
 
 
 
 
 
ee71a24
b259333
 
4e2b74e
288d9a2
b259333
4e2b74e
 
 
 
 
 
 
 
 
 
 
 
 
506d641
03d9529
 
 
 
 
 
 
 
 
 
b259333
 
4e2b74e
 
 
 
 
3dc48b7
 
28bcb40
4e2b74e
 
 
 
 
 
 
 
 
 
 
03d9529
 
 
 
 
 
 
 
4e2b74e
b259333
28bcb40
b259333
 
 
4e2b74e
 
b259333
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
"""
Layer 1 β€” GRPO training script for prompt optimization.

All parameters are loaded from config.yaml (single source of truth).
CLI flags override config.yaml values.

Usage:
    # Train with defaults from config.yaml
    python -m layer1.train

    # Override specific params
    python -m layer1.train --steps 20 --episodes 10

    # Evaluate a single prompt
    python -m layer1.train --mode eval --prompt "You are a helpful agent."
"""

from __future__ import annotations

import argparse
import json
import logging
import sys
import os
from datetime import datetime

# Auto-load .env for HF_TOKEN
from dotenv import load_dotenv
load_dotenv(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), ".env"))

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from config_loader import load_config, make_grpo_config, make_env_config, get_report_config, get_paths, get_generation_config, get_personas_config, get_upload_config
from layer1.grpo_trainer import GRPOConfig, GRPOPromptTrainer, PromptEvaluator, SFT_SEED_PROMPTS
from layer1.training_logger import TrainingLogger, ReportGenerator
from layer1.upload import SupabaseUploader
from layer2.customer_sim import CustomerPersona, CustomerSimulator
from layer2.hf_agent import HFAgent
from personas.generate_personas import generate_personas

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(message)s")
logger = logging.getLogger(__name__)


def verify_volume_mount(paths_cfg: dict) -> None:
    """Write a canary file at startup to verify the volume is mounted and writable.

    Args:
        paths_cfg: Paths section of the config; "output_dir" and "log_dir"
            entries are checked (empty/missing entries are skipped).

    Raises:
        OSError: If a configured directory cannot be created or written to.
    """
    output_dirs = [
        paths_cfg.get("output_dir", ""),
        paths_cfg.get("log_dir", ""),
    ]
    for d in output_dirs:
        if not d:
            continue
        canary = os.path.join(d, ".volume_check")
        try:
            # makedirs inside the try: a missing or read-only mount can fail
            # right here, and that case must trigger the same loud warning
            # instead of an unexplained traceback.
            os.makedirs(d, exist_ok=True)
            with open(canary, "w") as f:
                f.write(f"volume check {datetime.now().isoformat()}\n")
                f.flush()
                os.fsync(f.fileno())  # force bytes to disk so the check is real
            logger.info("Volume check OK: %s", d)
        except OSError as e:
            logger.error("VOLUME WRITE FAILED for %s: %s", d, e)
            print(f"\n*** WARNING: Cannot write to {d} β€” volume may not be mounted! ***\n")
            raise


def _try_load_local_model(gen_cfg: dict, hf_token: str | None):
    """Attempt to load Llama locally; return a LocalLlamaModel or None.

    Falls back to the HF Inference API (by returning None) unless the config
    pins inference_backend to "local", in which case a RuntimeError is raised
    when local inference is impossible.
    """
    backend = gen_cfg.get("inference_backend", "auto")
    force_local = backend == "local"

    if backend == "api":
        logger.info("inference_backend=api β€” using HF Inference API")
        return None

    try:
        import torch

        if torch.cuda.is_available():
            # Happy path: GPU present and heavy deps importable.
            from layer2.local_model import get_shared_model

            shared = get_shared_model(hf_token=hf_token)
            logger.info("Using local Llama model for Layer 2 inference")
            return shared

        if force_local:
            raise RuntimeError("inference_backend=local but no CUDA GPU available")
        logger.info("No CUDA GPU β€” falling back to HF Inference API")
        return None
    except ImportError:
        if force_local:
            raise RuntimeError(
                "inference_backend=local but torch/transformers not installed. "
                "Install with: pip install -e '.[train]'"
            )
        logger.info("torch/transformers not available β€” using HF Inference API")
        return None


def load_evaluator(
    hf_token: str | None = None,
    gen_cfg: dict | None = None,
    personas_cfg: dict | None = None,
) -> PromptEvaluator:
    """Build the PromptEvaluator: personas + customer simulator + LLM agent.

    Args:
        hf_token: HuggingFace token; falls back to the HF_TOKEN env var.
        gen_cfg: Generation section of config (token limits, temperatures).
        personas_cfg: Personas section of config ("count").

    Raises:
        RuntimeError: If neither a local model nor an HF token is available,
            or if the LLM agent fails to initialize.
    """
    gen_cfg = gen_cfg or {}
    personas_cfg = personas_cfg or {}
    token = hf_token or os.environ.get("HF_TOKEN")

    local_model = _try_load_local_model(gen_cfg, token)
    if local_model is None and not token:
        raise RuntimeError(
            "HF_TOKEN is required when not using local model. "
            "Set it via --hf-token or the HF_TOKEN environment variable."
        )

    raw_personas = generate_personas(personas_cfg.get("count", 100))
    personas = [CustomerPersona(**record) for record in raw_personas]

    simulator = CustomerSimulator(
        hf_token=token,
        max_tokens=gen_cfg.get("customer_max_tokens", 200),
        temperature=gen_cfg.get("customer_temperature", 0.7),
        local_model=local_model,
    )
    agent = HFAgent(
        hf_token=token,
        max_tokens=gen_cfg.get("agent_max_tokens", 300),
        temperature=gen_cfg.get("agent_temperature", 0.3),
        local_model=local_model,
    )
    if not agent.is_llm_available:
        raise RuntimeError(
            "LLM agent could not be initialized. Check your HF_TOKEN and huggingface_hub installation."
        )

    logger.info(
        "Using LLM agent (Llama 3.1 8B via %s)",
        "local model" if local_model else "HF Inference API",
    )
    return PromptEvaluator(personas=personas, simulator=simulator, agent_fn=agent)


def _print_config_banner(config: GRPOConfig, report_cfg: dict, paths_cfg: dict, gen_cfg: dict | None = None):
    """Print all training parameters from config."""
    total_conversations = (
        config.num_training_steps * config.num_candidates * config.episodes_per_candidate
    )

    print(f"\n{'='*70}")
    print(f"  TRAINING CONFIGURATION (from config.yaml)")
    print(f"{'='*70}")
    print()
    print(f"  --- Layer 1: GRPO RL Training ---")
    print(f"  Prompt Generator Model:        {config.model_name}")
    print(f"  LoRA:                          r={config.lora_r}  alpha={config.lora_alpha}  dropout={config.lora_dropout}")
    print(f"  Learning Rate:                 {config.learning_rate:.1e}")
    print(f"  Steps / GRPO Iterations:       {config.num_training_steps}")
    print(f"  Candidates / Customer Reps:    {config.num_candidates} per step")
    print(f"  Episodes / Customers:          {config.episodes_per_candidate} per candidate")
    print(f"  Max Prompt Length:             {config.max_prompt_length} tokens")
    print(f"  Batch Size:                    {config.per_device_train_batch_size}")
    print(f"  Gradient Accumulation:         {config.gradient_accumulation_steps}")
    print()
    print(f"  --- Layer 2: Conversation Environment ---")
    print(f"  Domain:                        {config.domain}")
    print(f"  Intents:                       {config.intents}")
    print(f"  Max Turns per Conversation:    (from env config)")
    print(f"  Inference Backend:             {gen_cfg.get('inference_backend', 'auto') if gen_cfg else 'auto'}")
    print(f"  Customer Rep Agent:            Llama 3.1 8B")
    print(f"  Customer Simulator:            Llama 3.1 8B")
    print()
    print(f"  --- Totals ---")
    print(f"  Total LLM Conversations:       ~{total_conversations}")
    print(f"  Report Generation:             {'yes' if report_cfg['enabled'] else 'no'}")
    print(f"  Output Dir:                    {paths_cfg['output_dir']}")
    print(f"  Log Dir:                       {paths_cfg['log_dir']}")
    print(f"{'='*70}\n")


def run_train(config: GRPOConfig, report_cfg: dict, paths_cfg: dict, hf_token: str | None, gen_cfg: dict | None = None, personas_cfg: dict | None = None, upload_cfg: dict | None = None):
    """Run GRPO training.

    Pipeline: config banner -> volume check -> evaluator -> logger/uploader
    wiring -> optional SFT warm start -> GRPO training -> best-prompt
    evaluation -> raw summary dump -> optional report -> upload finalization.

    Args:
        config: GRPO hyperparameters (steps, candidates, episodes, SFT options).
        report_cfg: Report section of config ("enabled", "output_dir",
            "eval_episodes", "example_customers").
        paths_cfg: Paths section of config ("output_dir", "log_dir").
        hf_token: Optional HuggingFace token forwarded to the evaluator.
        gen_cfg: Generation section (inference backend, temperatures, limits).
        personas_cfg: Personas section ("count").
        upload_cfg: Upload section ("enabled", "bucket").
    """
    _print_config_banner(config, report_cfg, paths_cfg, gen_cfg=gen_cfg)

    # Verify volume is mounted before doing any expensive work
    all_paths = dict(paths_cfg)
    if report_cfg.get("enabled") and report_cfg.get("output_dir"):
        all_paths["report_dir"] = report_cfg["output_dir"]
    verify_volume_mount(all_paths)

    evaluator = load_evaluator(hf_token, gen_cfg=gen_cfg, personas_cfg=personas_cfg)
    training_logger = TrainingLogger(
        log_dir=paths_cfg["log_dir"], total_steps=config.num_training_steps
    )

    # Wire up incremental Supabase uploads
    upload_cfg = upload_cfg or {}
    uploader = None
    # Uploads require both the config switch and SUPABASE_URL in the env.
    if upload_cfg.get("enabled") and os.environ.get("SUPABASE_URL"):
        uploader = SupabaseUploader(
            run_id=training_logger.timestamp,
            bucket=upload_cfg.get("bucket", "training-results"),
            config={"grpo": config.__dict__, "report": report_cfg, "paths": paths_cfg},
        )
        if uploader.enabled:
            # Push metrics after every training step via the logger callback.
            training_logger.add_on_step_callback(uploader.after_step)
            print("Supabase incremental upload enabled")
        else:
            # Uploader constructed but disabled itself (e.g. bad credentials).
            uploader = None
    elif upload_cfg.get("enabled"):
        print("Supabase upload enabled but SUPABASE_URL not set β€” skipping")

    trainer = GRPOPromptTrainer(config=config, evaluator=evaluator, logger=training_logger)
    trainer.setup_model()

    # SFT warm start: prime the model on hand-crafted seed prompts before GRPO
    if config.sft_warm_start:
        print(f"\n{'='*60}")
        print("SFT WARM START")
        print(f"{'='*60}")
        print(f"  Seed prompts: {len(SFT_SEED_PROMPTS)}")
        print(f"  Epochs: {config.sft_epochs}  |  LR: {config.sft_lr:.1e}")
        print(f"{'='*60}\n")
        trainer.sft_warm_start(num_epochs=config.sft_epochs, sft_lr=config.sft_lr)

    trainer.train()

    best_prompt = trainer.generate_best_prompt()
    print(f"\n{'='*60}")
    print("TRAINED SYSTEM PROMPT")
    print(f"{'='*60}")
    print(best_prompt)

    # Evaluate the trained prompt
    result = evaluator.evaluate_prompt(
        best_prompt, num_episodes=config.episodes_per_candidate
    )
    print(f"\nEvaluation: mean_reward={result['mean_reward']:.1f}")

    # Always output raw training summary (arrays for post-hoc analysis)
    print(f"\n{'='*60}")
    print("RAW TRAINING SUMMARY")
    print(f"{'='*60}")
    raw_summary = training_logger.generate_raw_summary()
    summary_path = training_logger.save_raw_summary(paths_cfg.get("log_dir"))
    print(f"Saved to: {summary_path}")
    print(f"\nSteps:        {raw_summary['steps']}")
    print(f"Mean rewards: {raw_summary['mean_rewards']}")
    print(f"Min rewards:  {raw_summary['min_rewards']}")
    print(f"Max rewards:  {raw_summary['max_rewards']}")
    print(f"Best step:    {raw_summary['best_step']} (reward={raw_summary['best_mean_reward']})")
    print(f"Total episodes: {raw_summary['total_episodes']}")
    print(f"Duration:     {raw_summary['duration_seconds']}s")
    print(f"\nPer-step episode rewards:")
    for step, rewards in zip(raw_summary["steps"], raw_summary["all_episode_rewards"]):
        print(f"  Step {step:3d}: {rewards}")
    print(f"\nFull raw JSON: {summary_path}")
    print(f"{'='*60}")

    report_path = None
    if report_cfg["enabled"]:
        print(f"\n{'='*60}")
        print("GENERATING TRAINING REPORT...")
        print(f"{'='*60}")
        report_gen = ReportGenerator(evaluator, training_logger)
        report_path = report_gen.generate_report(
            output_dir=report_cfg["output_dir"],
            num_eval_episodes=report_cfg["eval_episodes"],
            num_example_customers=report_cfg["example_customers"],
        )
        print(f"\nReport saved to {report_path}")

        # Print report to stdout as fallback (always visible in logs)
        try:
            with open(report_path, "r") as f:
                report_content = f.read()
            print(f"\n{'='*60}")
            print("REPORT CONTENT (stdout fallback)")
            print(f"{'='*60}")
            print(report_content)
            print(f"{'='*60}")
        except OSError:
            # Best-effort: report already exists on disk, so do not crash here.
            print("WARNING: Could not re-read report from disk")

    # Finalize Supabase upload (update duration, upload files)
    if uploader and uploader.enabled:
        print(f"\n{'='*60}")
        print("FINALIZING SUPABASE UPLOAD...")
        print(f"{'='*60}")
        uploader.finish(
            duration_seconds=raw_summary.get("duration_seconds"),
            report_path=report_path,
            raw_summary=raw_summary,
        )
        print(f"  Run ID:  {uploader.run_id}")
        print(f"  Steps uploaded incrementally: {len(uploader._mean_rewards)}")
        print(f"  Episodes uploaded: {uploader._total_episodes}")
        print(f"{'='*60}")


def run_eval(hf_token: str | None, prompt: str, episodes: int):
    """Evaluate a single system prompt and print per-episode results."""
    evaluator = load_evaluator(hf_token)
    result = evaluator.evaluate_prompt(prompt, num_episodes=episodes)

    print(f"Prompt: {prompt[:80]}...")
    print(f"Mean reward: {result['mean_reward']:.1f}")
    print(f"Min/Max: {result['min_reward']:.1f} / {result['max_reward']:.1f}")

    rewards = result["rewards"]
    for idx, episode_log in enumerate(result["logs"]):
        print(
            f"  Episode {idx}: intent={episode_log['true_intent']} "
            f"correct={episode_log['intent_correct']} turns={episode_log['turns']} "
            f"reward={rewards[idx]:.1f}"
        )


def main():
    """CLI entry point: parse args, merge them over config.yaml, dispatch."""
    print("Version: 0.0")
    parser = argparse.ArgumentParser(description="Layer 1 β€” GRPO Prompt Optimizer")
    parser.add_argument(
        "--mode", choices=["train", "eval"], default="train",
        help="Mode: train (GRPO RL training), eval (evaluate a single prompt)",
    )
    # Flag table: value flags default to None so "not given" is
    # distinguishable from an explicit value; --no-report is a boolean switch.
    flag_specs = [
        ("--config", {"type": str, "help": "Path to config.yaml (default: ./config.yaml)"}),
        ("--episodes", {"type": int, "help": "Override episodes_per_candidate from config"}),
        ("--steps", {"type": int, "help": "Override num_training_steps from config"}),
        ("--output-dir", {"type": str, "help": "Override output directory from config"}),
        ("--hf-token", {"type": str, "help": "HuggingFace API token"}),
        ("--prompt", {"type": str, "help": "Prompt to evaluate (eval mode)"}),
        ("--no-report", {"action": "store_true", "help": "Skip report generation"}),
        ("--report-dir", {"type": str, "help": "Override report output directory from config"}),
        ("--log-dir", {"type": str, "help": "Override log directory from config"}),
        ("--eval-episodes", {"type": int, "help": "Override eval episodes for report from config"}),
        ("--example-customers", {"type": int, "help": "Override example customers in report from config"}),
        ("--output", {"type": str, "help": "Save results to JSON file"}),
    ]
    for flag, kwargs in flag_specs:
        if "action" in kwargs:
            parser.add_argument(flag, **kwargs)
        else:
            parser.add_argument(flag, default=None, **kwargs)
    args = parser.parse_args()

    # config.yaml is the single source of truth; CLI flags override it below.
    cfg = load_config(args.config)
    grpo_config = make_grpo_config(cfg)
    report_cfg = get_report_config(cfg)
    paths_cfg = get_paths(cfg)
    gen_cfg = get_generation_config(cfg)
    personas_cfg = get_personas_config(cfg)
    upload_cfg = get_upload_config(cfg)

    # CLI overrides
    if args.steps is not None:
        grpo_config.num_training_steps = args.steps
    if args.episodes is not None:
        grpo_config.episodes_per_candidate = args.episodes
    if args.output_dir is not None:
        # Keep the trainer config and the paths dict in sync.
        grpo_config.output_dir = args.output_dir
        paths_cfg["output_dir"] = args.output_dir
    if args.no_report:
        report_cfg["enabled"] = False
    if args.report_dir is not None:
        report_cfg["output_dir"] = args.report_dir
    if args.log_dir is not None:
        paths_cfg["log_dir"] = args.log_dir
    if args.eval_episodes is not None:
        report_cfg["eval_episodes"] = args.eval_episodes
    if args.example_customers is not None:
        report_cfg["example_customers"] = args.example_customers

    if args.mode == "eval":
        if not args.prompt:
            parser.error("--prompt is required for eval mode")
        episodes = args.episodes or grpo_config.episodes_per_candidate
        run_eval(args.hf_token, args.prompt, episodes)
    else:
        run_train(grpo_config, report_cfg, paths_cfg, args.hf_token,
                  gen_cfg=gen_cfg, personas_cfg=personas_cfg, upload_cfg=upload_cfg)


# Script entry point: supports both `python -m layer1.train` and direct execution.
if __name__ == "__main__":
    main()