Claude committed on
Commit
28bcb40
·
unverified ·
1 Parent(s): 71b0977

Add Supabase upload for training results (Storage + DB)

Browse files

- layer1/upload.py: uploads raw summary JSON, report, and chart to
Supabase Storage; inserts per-run and per-episode metrics into
training_runs and training_episodes Postgres tables
- scripts/supabase_setup.sql: migration to create tables, indexes,
and RLS policies — run in the Supabase SQL Editor before the first training run
- config.yaml: upload section with enabled flag and bucket name
- config_loader.py: get_upload_config() for new section
- pyproject.toml: supabase>=2.0.0 as optional [upload] dependency

Requires SUPABASE_URL and SUPABASE_KEY env vars. Gracefully skips
upload if not configured (logs warning, training still completes).

https://claude.ai/code/session_01DPirJ78YYN4fJUvUFJ5D6V

config.yaml CHANGED
@@ -97,6 +97,15 @@ report:
97
  example_customers: 5 # Example conversations in report
98
 
99
 
 
 
 
 
 
 
 
 
 
100
  # --- Paths ---
101
 
102
  paths:
 
97
  example_customers: 5 # Example conversations in report
98
 
99
 
100
+ # --- Upload: Supabase ---
101
+ # Upload training results to Supabase for analysis.
102
+ # Requires SUPABASE_URL and SUPABASE_KEY environment variables.
103
+
104
+ upload:
105
+ enabled: true
106
+ bucket: "training-results" # Supabase Storage bucket name
107
+
108
+
109
  # --- Paths ---
110
 
111
  paths:
config_loader.py CHANGED
@@ -126,6 +126,15 @@ def get_generation_config(cfg: dict[str, Any]) -> dict[str, Any]:
126
  }
127
 
128
 
 
 
 
 
 
 
 
 
 
129
  def get_personas_config(cfg: dict[str, Any]) -> dict[str, Any]:
130
  """Extract persona settings from config."""
131
  personas = cfg.get("personas", {})
 
126
  }
127
 
128
 
129
def get_upload_config(cfg: dict[str, Any]) -> dict[str, Any]:
    """Extract Supabase upload settings from config.

    Args:
        cfg: Full parsed config dict (output of load_config).

    Returns:
        Dict with "enabled" (bool, default False) and "bucket" (str,
        default "training-results").
    """
    # `or {}` guards against a YAML file containing a bare `upload:` section
    # with no keys, which parses to None and would crash the .get() below.
    upload = cfg.get("upload") or {}
    return {
        "enabled": upload.get("enabled", False),
        "bucket": upload.get("bucket", "training-results"),
    }
136
+
137
+
138
  def get_personas_config(cfg: dict[str, Any]) -> dict[str, Any]:
139
  """Extract persona settings from config."""
140
  personas = cfg.get("personas", {})
layer1/train.py CHANGED
@@ -30,9 +30,10 @@ load_dotenv(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file_
30
 
31
  sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
32
 
33
- from config_loader import load_config, make_grpo_config, make_env_config, get_report_config, get_paths, get_generation_config, get_personas_config
34
  from layer1.grpo_trainer import GRPOConfig, GRPOPromptTrainer, PromptEvaluator
35
  from layer1.training_logger import TrainingLogger, ReportGenerator
 
36
  from layer2.customer_sim import CustomerPersona, CustomerSimulator
37
  from layer2.hf_agent import HFAgent
38
  from personas.generate_personas import generate_personas
@@ -138,7 +139,7 @@ def _print_config_banner(config: GRPOConfig, report_cfg: dict, paths_cfg: dict):
138
  print(f"{'='*70}\n")
139
 
140
 
141
- def run_train(config: GRPOConfig, report_cfg: dict, paths_cfg: dict, hf_token: str | None, gen_cfg: dict | None = None, personas_cfg: dict | None = None):
142
  """Run GRPO training."""
143
  _print_config_banner(config, report_cfg, paths_cfg)
144
 
@@ -188,6 +189,7 @@ def run_train(config: GRPOConfig, report_cfg: dict, paths_cfg: dict, hf_token: s
188
  print(f"\nFull raw JSON: {summary_path}")
189
  print(f"{'='*60}")
190
 
 
191
  if report_cfg["enabled"]:
192
  print(f"\n{'='*60}")
193
  print("GENERATING TRAINING REPORT...")
@@ -212,6 +214,29 @@ def run_train(config: GRPOConfig, report_cfg: dict, paths_cfg: dict, hf_token: s
212
  except OSError:
213
  print("WARNING: Could not re-read report from disk")
214
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
 
216
  def run_eval(hf_token: str | None, prompt: str, episodes: int):
217
  """Evaluate a single prompt."""
@@ -268,6 +293,7 @@ def main():
268
  paths_cfg = get_paths(cfg)
269
  gen_cfg = get_generation_config(cfg)
270
  personas_cfg = get_personas_config(cfg)
 
271
 
272
  # CLI overrides
273
  if args.steps is not None:
@@ -289,7 +315,7 @@ def main():
289
  report_cfg["example_customers"] = args.example_customers
290
 
291
  if args.mode == "train":
292
- run_train(grpo_config, report_cfg, paths_cfg, args.hf_token, gen_cfg=gen_cfg, personas_cfg=personas_cfg)
293
  elif args.mode == "eval":
294
  if not args.prompt:
295
  parser.error("--prompt is required for eval mode")
 
30
 
31
  sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
32
 
33
+ from config_loader import load_config, make_grpo_config, make_env_config, get_report_config, get_paths, get_generation_config, get_personas_config, get_upload_config
34
  from layer1.grpo_trainer import GRPOConfig, GRPOPromptTrainer, PromptEvaluator
35
  from layer1.training_logger import TrainingLogger, ReportGenerator
36
+ from layer1.upload import upload_training_results
37
  from layer2.customer_sim import CustomerPersona, CustomerSimulator
38
  from layer2.hf_agent import HFAgent
39
  from personas.generate_personas import generate_personas
 
139
  print(f"{'='*70}\n")
140
 
141
 
142
+ def run_train(config: GRPOConfig, report_cfg: dict, paths_cfg: dict, hf_token: str | None, gen_cfg: dict | None = None, personas_cfg: dict | None = None, upload_cfg: dict | None = None):
143
  """Run GRPO training."""
144
  _print_config_banner(config, report_cfg, paths_cfg)
145
 
 
189
  print(f"\nFull raw JSON: {summary_path}")
190
  print(f"{'='*60}")
191
 
192
+ report_path = None
193
  if report_cfg["enabled"]:
194
  print(f"\n{'='*60}")
195
  print("GENERATING TRAINING REPORT...")
 
214
  except OSError:
215
  print("WARNING: Could not re-read report from disk")
216
 
217
+ # Upload to Supabase if configured
218
+ upload_cfg = upload_cfg or {}
219
+ if upload_cfg.get("enabled") and os.environ.get("SUPABASE_URL"):
220
+ print(f"\n{'='*60}")
221
+ print("UPLOADING TO SUPABASE...")
222
+ print(f"{'='*60}")
223
+ upload_result = upload_training_results(
224
+ raw_summary=raw_summary,
225
+ run_id=training_logger.timestamp,
226
+ bucket=upload_cfg.get("bucket", "training-results"),
227
+ report_path=report_path if report_cfg["enabled"] else None,
228
+ chart_path=None, # chart path is internal to ReportGenerator
229
+ config={"grpo": config.__dict__, "report": report_cfg, "paths": paths_cfg},
230
+ )
231
+ print(f" Run ID: {upload_result['run_id']}")
232
+ print(f" Files: {len(upload_result['storage_paths'])} uploaded")
233
+ print(f" DB rows: {upload_result['db_rows']}")
234
+ if upload_result.get("error"):
235
+ print(f" Error: {upload_result['error']}")
236
+ print(f"{'='*60}")
237
+ elif upload_cfg.get("enabled"):
238
+ print("\nSupabase upload enabled but SUPABASE_URL not set — skipping")
239
+
240
 
241
  def run_eval(hf_token: str | None, prompt: str, episodes: int):
242
  """Evaluate a single prompt."""
 
293
  paths_cfg = get_paths(cfg)
294
  gen_cfg = get_generation_config(cfg)
295
  personas_cfg = get_personas_config(cfg)
296
+ upload_cfg = get_upload_config(cfg)
297
 
298
  # CLI overrides
299
  if args.steps is not None:
 
315
  report_cfg["example_customers"] = args.example_customers
316
 
317
  if args.mode == "train":
318
+ run_train(grpo_config, report_cfg, paths_cfg, args.hf_token, gen_cfg=gen_cfg, personas_cfg=personas_cfg, upload_cfg=upload_cfg)
319
  elif args.mode == "eval":
320
  if not args.prompt:
321
  parser.error("--prompt is required for eval mode")
layer1/upload.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Supabase uploader for training results.
3
+
4
+ Uploads:
5
+ 1. Raw summary JSON + report files to Supabase Storage
6
+ 2. Per-run and per-episode metrics to Postgres tables
7
+
8
+ Requires SUPABASE_URL and SUPABASE_KEY environment variables.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import json
14
+ import logging
15
+ import os
16
+ from datetime import datetime, timezone
17
+ from typing import Any
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
def _get_client():
    """Create a Supabase client from environment variables.

    Returns None (after logging an error) when either the ``supabase``
    package is not installed or the required environment variables are
    missing; callers treat None as "skip the upload".
    """
    try:
        from supabase import create_client
    except ImportError:
        logger.error(
            "supabase package not installed. Install with: pip install 'nested-rl-envs[upload]'"
        )
        return None

    url = os.environ.get("SUPABASE_URL")
    key = os.environ.get("SUPABASE_KEY")
    if url and key:
        return create_client(url, key)

    logger.error("SUPABASE_URL and SUPABASE_KEY must be set")
    return None
39
+
40
+
41
def upload_training_results(
    raw_summary: dict[str, Any],
    run_id: str | None = None,
    bucket: str = "training-results",
    report_path: str | None = None,
    chart_path: str | None = None,
    config: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """
    Upload training results to Supabase (Storage + DB).

    Best-effort by design: failures inside the helpers are logged and
    swallowed so that an upload problem never aborts training.

    Args:
        raw_summary: Output of TrainingLogger.generate_raw_summary().
        run_id: Unique run identifier. Auto-generated (UTC timestamp) if not provided.
        bucket: Supabase Storage bucket name.
        report_path: Path to the markdown report file (optional).
        chart_path: Path to the reward chart PNG (optional).
        config: Training config dict to store with the run (optional).

    Returns:
        Dict with keys "run_id", "storage_paths", "db_rows", and "error".
        "error" is None on success — the key is now always present because
        the caller in train.py inspects upload_result.get("error").
    """
    client = _get_client()
    if client is None:
        logger.warning("Supabase upload skipped — client not available")
        return {"run_id": None, "storage_paths": [], "db_rows": 0, "error": "no client"}

    if run_id is None:
        run_id = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")

    # Uniform result schema for both the success and the failure path.
    results: dict[str, Any] = {
        "run_id": run_id,
        "storage_paths": [],
        "db_rows": 0,
        "error": None,
    }

    # --- Storage uploads (summary JSON, report, chart) ---
    results["storage_paths"] = _upload_files(
        client, bucket, run_id, raw_summary, report_path, chart_path
    )

    # --- DB inserts (run summary + per-episode metrics) ---
    results["db_rows"] = _insert_metrics(client, run_id, raw_summary, config)

    logger.info(
        "Supabase upload complete: run_id=%s, files=%d, db_rows=%d",
        run_id, len(results["storage_paths"]), results["db_rows"],
    )
    return results
86
+
87
+
88
+ def _upload_files(
89
+ client,
90
+ bucket: str,
91
+ run_id: str,
92
+ raw_summary: dict[str, Any],
93
+ report_path: str | None,
94
+ chart_path: str | None,
95
+ ) -> list[str]:
96
+ """Upload files to Supabase Storage."""
97
+ uploaded = []
98
+
99
+ # Upload raw summary JSON
100
+ try:
101
+ summary_bytes = json.dumps(raw_summary, indent=2, default=str).encode()
102
+ path = f"{run_id}/raw_summary.json"
103
+ client.storage.from_(bucket).upload(
104
+ path, summary_bytes, {"content-type": "application/json"}
105
+ )
106
+ uploaded.append(path)
107
+ logger.info("Uploaded %s to storage", path)
108
+ except Exception as e:
109
+ logger.error("Failed to upload raw_summary.json: %s", e)
110
+
111
+ # Upload report markdown
112
+ if report_path and os.path.exists(report_path):
113
+ try:
114
+ with open(report_path, "rb") as f:
115
+ path = f"{run_id}/report.md"
116
+ client.storage.from_(bucket).upload(
117
+ path, f.read(), {"content-type": "text/markdown"}
118
+ )
119
+ uploaded.append(path)
120
+ logger.info("Uploaded %s to storage", path)
121
+ except Exception as e:
122
+ logger.error("Failed to upload report: %s", e)
123
+
124
+ # Upload chart PNG
125
+ if chart_path and os.path.exists(chart_path):
126
+ try:
127
+ with open(chart_path, "rb") as f:
128
+ path = f"{run_id}/reward_chart.png"
129
+ client.storage.from_(bucket).upload(
130
+ path, f.read(), {"content-type": "image/png"}
131
+ )
132
+ uploaded.append(path)
133
+ logger.info("Uploaded %s to storage", path)
134
+ except Exception as e:
135
+ logger.error("Failed to upload chart: %s", e)
136
+
137
+ return uploaded
138
+
139
+
140
+ def _insert_metrics(
141
+ client,
142
+ run_id: str,
143
+ raw_summary: dict[str, Any],
144
+ config: dict[str, Any] | None,
145
+ ) -> int:
146
+ """Insert training run + per-episode metrics into Postgres tables."""
147
+ rows_inserted = 0
148
+
149
+ # Insert training run summary
150
+ try:
151
+ run_row = {
152
+ "run_id": run_id,
153
+ "started_at": datetime.now(timezone.utc).isoformat(),
154
+ "duration_seconds": raw_summary.get("duration_seconds"),
155
+ "total_steps": len(raw_summary.get("steps", [])),
156
+ "total_episodes": raw_summary.get("total_episodes", 0),
157
+ "best_step": raw_summary.get("best_step"),
158
+ "best_mean_reward": raw_summary.get("best_mean_reward"),
159
+ "mean_rewards": raw_summary.get("mean_rewards", []),
160
+ "min_rewards": raw_summary.get("min_rewards", []),
161
+ "max_rewards": raw_summary.get("max_rewards", []),
162
+ "config": config,
163
+ }
164
+ client.table("training_runs").insert(run_row).execute()
165
+ rows_inserted += 1
166
+ logger.info("Inserted training run: %s", run_id)
167
+ except Exception as e:
168
+ logger.error("Failed to insert training_runs row: %s", e)
169
+
170
+ # Insert per-episode metrics in batches
171
+ episode_rows = []
172
+ for m in raw_summary.get("per_episode_metrics", []):
173
+ episode_rows.append({
174
+ "run_id": run_id,
175
+ "step": m["step"],
176
+ "episode": m["episode"],
177
+ "reward": m.get("reward"),
178
+ "turns": m.get("turns", 0),
179
+ "intent_captured": m.get("intent_captured", False),
180
+ "intent_correct": m.get("intent_correct", False),
181
+ "true_intent": m.get("true_intent", ""),
182
+ "agent_intent": m.get("agent_intent", ""),
183
+ "injection_attempted": m.get("injection_attempted", False),
184
+ "injection_succeeded": m.get("injection_succeeded", False),
185
+ "api_call_made": m.get("api_call_made", False),
186
+ "api_call_correct": m.get("api_call_correct", False),
187
+ })
188
+
189
+ # Batch insert (Supabase/PostgREST supports bulk inserts)
190
+ if episode_rows:
191
+ batch_size = 100
192
+ for i in range(0, len(episode_rows), batch_size):
193
+ batch = episode_rows[i : i + batch_size]
194
+ try:
195
+ client.table("training_episodes").insert(batch).execute()
196
+ rows_inserted += len(batch)
197
+ except Exception as e:
198
+ logger.error("Failed to insert episode batch %d: %s", i, e)
199
+
200
+ return rows_inserted
pyproject.toml CHANGED
@@ -32,6 +32,9 @@ train = [
32
  "accelerate>=0.27.0",
33
  "datasets>=2.18.0",
34
  ]
 
 
 
35
  dev = [
36
  "pytest>=8.0",
37
  "ruff>=0.3.0",
 
32
  "accelerate>=0.27.0",
33
  "datasets>=2.18.0",
34
  ]
35
+ upload = [
36
+ "supabase>=2.0.0",
37
+ ]
38
  dev = [
39
  "pytest>=8.0",
40
  "ruff>=0.3.0",
scripts/supabase_setup.sql ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
-- Supabase schema for training results
-- Run this in your Supabase SQL Editor to create the tables.

-- Training runs: one row per training run
create table if not exists training_runs (
  id bigint generated always as identity primary key,
  run_id text unique not null,          -- external identifier used by the uploader
  started_at timestamptz default now(),
  duration_seconds real,
  total_steps int,
  total_episodes int,
  best_step int,
  best_mean_reward real,
  mean_rewards jsonb,                   -- array of mean rewards per step
  min_rewards jsonb,                    -- array of min rewards per step
  max_rewards jsonb,                    -- array of max rewards per step
  config jsonb,                         -- full training config snapshot
  created_at timestamptz default now()
);

-- Per-episode metrics: one row per episode
create table if not exists training_episodes (
  id bigint generated always as identity primary key,
  -- cascade delete: removing a run removes all its episode rows
  run_id text not null references training_runs(run_id) on delete cascade,
  step int not null,
  episode int not null,
  reward real,
  turns int,
  intent_captured boolean default false,
  intent_correct boolean default false,
  true_intent text,
  agent_intent text,
  injection_attempted boolean default false,
  injection_succeeded boolean default false,
  api_call_made boolean default false,
  api_call_correct boolean default false,
  created_at timestamptz default now()
);

-- Index for fast queries by run
create index if not exists idx_episodes_run_id on training_episodes(run_id);
create index if not exists idx_episodes_step on training_episodes(run_id, step);

-- Create the storage bucket (run via Supabase Dashboard > Storage > New Bucket)
-- Bucket name: training-results
-- Public: false (use service key for uploads)

-- Enable Row Level Security (optional but recommended)
alter table training_runs enable row level security;
alter table training_episodes enable row level security;

-- Allow inserts with service key (anon or service_role)
-- NOTE(review): `with check (true)` / `using (true)` makes these policies
-- fully permissive for ANY role, including anon — confirm this is intended
-- before exposing the anon key publicly.
create policy "Allow insert training_runs" on training_runs
  for insert with check (true);
create policy "Allow select training_runs" on training_runs
  for select using (true);

create policy "Allow insert training_episodes" on training_episodes
  for insert with check (true);
create policy "Allow select training_episodes" on training_episodes
  for select using (true);