Claude committed on
Commit
76f180f
·
unverified ·
1 Parent(s): 726152d

Make Supabase uploads incremental — upload after every step

Browse files

Instead of uploading once at the end (risking total data loss on
crash/timeout), the uploader now:
- Creates/upserts the training_runs row after each step
- Inserts episode rows immediately after each step
- Calls finish() at end to update duration and upload files

Also adds a callback hook to TrainingLogger so the uploader is
notified automatically via log_iteration().

Note: requires adding an UPDATE RLS policy on training_runs
(see updated supabase_setup.sql).

https://claude.ai/code/session_01DPirJ78YYN4fJUvUFJ5D6V

layer1/train.py CHANGED
@@ -33,7 +33,7 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
33
  from config_loader import load_config, make_grpo_config, make_env_config, get_report_config, get_paths, get_generation_config, get_personas_config, get_upload_config
34
  from layer1.grpo_trainer import GRPOConfig, GRPOPromptTrainer, PromptEvaluator
35
  from layer1.training_logger import TrainingLogger, ReportGenerator
36
- from layer1.upload import upload_training_results
37
  from layer2.customer_sim import CustomerPersona, CustomerSimulator
38
  from layer2.hf_agent import HFAgent
39
  from personas.generate_personas import generate_personas
@@ -153,6 +153,24 @@ def run_train(config: GRPOConfig, report_cfg: dict, paths_cfg: dict, hf_token: s
153
  training_logger = TrainingLogger(
154
  log_dir=paths_cfg["log_dir"], total_steps=config.num_training_steps
155
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  trainer = GRPOPromptTrainer(config=config, evaluator=evaluator, logger=training_logger)
157
  trainer.setup_model()
158
  trainer.train()
@@ -214,28 +232,20 @@ def run_train(config: GRPOConfig, report_cfg: dict, paths_cfg: dict, hf_token: s
214
  except OSError:
215
  print("WARNING: Could not re-read report from disk")
216
 
217
- # Upload to Supabase if configured
218
- upload_cfg = upload_cfg or {}
219
- if upload_cfg.get("enabled") and os.environ.get("SUPABASE_URL"):
220
  print(f"\n{'='*60}")
221
- print("UPLOADING TO SUPABASE...")
222
  print(f"{'='*60}")
223
- upload_result = upload_training_results(
 
 
224
  raw_summary=raw_summary,
225
- run_id=training_logger.timestamp,
226
- bucket=upload_cfg.get("bucket", "training-results"),
227
- report_path=report_path if report_cfg["enabled"] else None,
228
- chart_path=None, # chart path is internal to ReportGenerator
229
- config={"grpo": config.__dict__, "report": report_cfg, "paths": paths_cfg},
230
  )
231
- print(f" Run ID: {upload_result['run_id']}")
232
- print(f" Files: {len(upload_result['storage_paths'])} uploaded")
233
- print(f" DB rows: {upload_result['db_rows']}")
234
- if upload_result.get("error"):
235
- print(f" Error: {upload_result['error']}")
236
  print(f"{'='*60}")
237
- elif upload_cfg.get("enabled"):
238
- print("\nSupabase upload enabled but SUPABASE_URL not set — skipping")
239
 
240
 
241
  def run_eval(hf_token: str | None, prompt: str, episodes: int):
 
33
  from config_loader import load_config, make_grpo_config, make_env_config, get_report_config, get_paths, get_generation_config, get_personas_config, get_upload_config
34
  from layer1.grpo_trainer import GRPOConfig, GRPOPromptTrainer, PromptEvaluator
35
  from layer1.training_logger import TrainingLogger, ReportGenerator
36
+ from layer1.upload import SupabaseUploader
37
  from layer2.customer_sim import CustomerPersona, CustomerSimulator
38
  from layer2.hf_agent import HFAgent
39
  from personas.generate_personas import generate_personas
 
153
  training_logger = TrainingLogger(
154
  log_dir=paths_cfg["log_dir"], total_steps=config.num_training_steps
155
  )
156
+
157
+ # Wire up incremental Supabase uploads
158
+ upload_cfg = upload_cfg or {}
159
+ uploader = None
160
+ if upload_cfg.get("enabled") and os.environ.get("SUPABASE_URL"):
161
+ uploader = SupabaseUploader(
162
+ run_id=training_logger.timestamp,
163
+ bucket=upload_cfg.get("bucket", "training-results"),
164
+ config={"grpo": config.__dict__, "report": report_cfg, "paths": paths_cfg},
165
+ )
166
+ if uploader.enabled:
167
+ training_logger.add_on_step_callback(uploader.after_step)
168
+ print("Supabase incremental upload enabled")
169
+ else:
170
+ uploader = None
171
+ elif upload_cfg.get("enabled"):
172
+ print("Supabase upload enabled but SUPABASE_URL not set — skipping")
173
+
174
  trainer = GRPOPromptTrainer(config=config, evaluator=evaluator, logger=training_logger)
175
  trainer.setup_model()
176
  trainer.train()
 
232
  except OSError:
233
  print("WARNING: Could not re-read report from disk")
234
 
235
+ # Finalize Supabase upload (update duration, upload files)
236
+ if uploader and uploader.enabled:
 
237
  print(f"\n{'='*60}")
238
+ print("FINALIZING SUPABASE UPLOAD...")
239
  print(f"{'='*60}")
240
+ uploader.finish(
241
+ duration_seconds=raw_summary.get("duration_seconds"),
242
+ report_path=report_path,
243
  raw_summary=raw_summary,
 
 
 
 
 
244
  )
245
+ print(f" Run ID: {uploader.run_id}")
246
+ print(f" Steps uploaded incrementally: {len(uploader._mean_rewards)}")
247
+ print(f" Episodes uploaded: {uploader._total_episodes}")
 
 
248
  print(f"{'='*60}")
 
 
249
 
250
 
251
  def run_eval(hf_token: str | None, prompt: str, episodes: int):
layer1/training_logger.py CHANGED
@@ -32,6 +32,7 @@ class TrainingLogger:
32
  self.total_steps = total_steps
33
  self.iterations: list[dict[str, Any]] = []
34
  self._start_time = datetime.now()
 
35
 
36
  with open(self.log_path, "w") as f:
37
  f.write(f"Training Log — {self._start_time.isoformat()}\n")
@@ -39,6 +40,10 @@ class TrainingLogger:
39
  f.flush()
40
  os.fsync(f.fileno())
41
 
 
 
 
 
42
  def log_iteration(self, step: int, prompt: str, eval_result: dict[str, Any]):
43
  """Log a single training iteration (one prompt evaluated)."""
44
  entry = {
@@ -69,6 +74,13 @@ class TrainingLogger:
69
 
70
  logger.info("Logged step %d: mean_reward=%.1f", step, entry["mean_reward"])
71
 
 
 
 
 
 
 
 
72
  def save_json(self):
73
  """Save structured training data to JSON."""
74
  data = {
 
32
  self.total_steps = total_steps
33
  self.iterations: list[dict[str, Any]] = []
34
  self._start_time = datetime.now()
35
+ self._on_step_callbacks: list[Any] = []
36
 
37
  with open(self.log_path, "w") as f:
38
  f.write(f"Training Log — {self._start_time.isoformat()}\n")
 
40
  f.flush()
41
  os.fsync(f.fileno())
42
 
43
+ def add_on_step_callback(self, callback):
44
+ """Register a callback called after each step: callback(step, eval_result, prompt)."""
45
+ self._on_step_callbacks.append(callback)
46
+
47
  def log_iteration(self, step: int, prompt: str, eval_result: dict[str, Any]):
48
  """Log a single training iteration (one prompt evaluated)."""
49
  entry = {
 
74
 
75
  logger.info("Logged step %d: mean_reward=%.1f", step, entry["mean_reward"])
76
 
77
+ # Notify callbacks (e.g. Supabase uploader)
78
+ for cb in self._on_step_callbacks:
79
+ try:
80
+ cb(step, eval_result, prompt)
81
+ except Exception as e:
82
+ logger.error("Step callback failed: %s", e)
83
+
84
  def save_json(self):
85
  """Save structured training data to JSON."""
86
  data = {
layer1/upload.py CHANGED
@@ -1,9 +1,11 @@
1
  """
2
- Supabase uploader for training results.
3
 
4
- Uploads:
5
- 1. Raw summary JSON + report files to Supabase Storage
6
- 2. Per-run and per-episode metrics to Postgres tables
 
 
7
 
8
  Requires SUPABASE_URL and SUPABASE_KEY environment variables.
9
  """
@@ -38,163 +40,172 @@ def _get_client():
38
  return create_client(url, key)
39
 
40
 
41
- def upload_training_results(
42
- raw_summary: dict[str, Any],
43
- run_id: str | None = None,
44
- bucket: str = "training-results",
45
- report_path: str | None = None,
46
- chart_path: str | None = None,
47
- config: dict[str, Any] | None = None,
48
- ) -> dict[str, Any]:
49
  """
50
- Upload training results to Supabase (Storage + DB).
51
-
52
- Args:
53
- raw_summary: Output of TrainingLogger.generate_raw_summary().
54
- run_id: Unique run identifier. Auto-generated if not provided.
55
- bucket: Supabase Storage bucket name.
56
- report_path: Path to the markdown report file (optional).
57
- chart_path: Path to the reward chart PNG (optional).
58
- config: Training config dict to store with the run (optional).
59
-
60
- Returns:
61
- Dict with upload results: {"run_id", "storage_paths", "db_rows"}.
62
  """
63
- client = _get_client()
64
- if client is None:
65
- logger.warning("Supabase upload skipped — client not available")
66
- return {"run_id": None, "storage_paths": [], "db_rows": 0, "error": "no client"}
67
-
68
- if run_id is None:
69
- run_id = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
70
-
71
- results: dict[str, Any] = {"run_id": run_id, "storage_paths": [], "db_rows": 0}
72
-
73
- # --- Storage uploads ---
74
- results["storage_paths"] = _upload_files(
75
- client, bucket, run_id, raw_summary, report_path, chart_path
76
- )
77
-
78
- # --- DB inserts ---
79
- results["db_rows"] = _insert_metrics(client, run_id, raw_summary, config)
80
-
81
- logger.info(
82
- "Supabase upload complete: run_id=%s, files=%d, db_rows=%d",
83
- run_id, len(results["storage_paths"]), results["db_rows"],
84
- )
85
- return results
86
-
87
-
88
- def _upload_files(
89
- client,
90
- bucket: str,
91
- run_id: str,
92
- raw_summary: dict[str, Any],
93
- report_path: str | None,
94
- chart_path: str | None,
95
- ) -> list[str]:
96
- """Upload files to Supabase Storage."""
97
- uploaded = []
98
-
99
- # Upload raw summary JSON
100
- try:
101
- summary_bytes = json.dumps(raw_summary, indent=2, default=str).encode()
102
- path = f"{run_id}/raw_summary.json"
103
- client.storage.from_(bucket).upload(
104
- path, summary_bytes, {"content-type": "application/json"}
105
- )
106
- uploaded.append(path)
107
- logger.info("Uploaded %s to storage", path)
108
- except Exception as e:
109
- logger.error("Failed to upload raw_summary.json: %s", e)
110
 
111
- # Upload report markdown
112
- if report_path and os.path.exists(report_path):
113
- try:
114
- with open(report_path, "rb") as f:
115
- path = f"{run_id}/report.md"
116
- client.storage.from_(bucket).upload(
117
- path, f.read(), {"content-type": "text/markdown"}
118
- )
119
- uploaded.append(path)
120
- logger.info("Uploaded %s to storage", path)
121
- except Exception as e:
122
- logger.error("Failed to upload report: %s", e)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
- # Upload chart PNG
125
- if chart_path and os.path.exists(chart_path):
126
  try:
127
- with open(chart_path, "rb") as f:
128
- path = f"{run_id}/reward_chart.png"
129
- client.storage.from_(bucket).upload(
130
- path, f.read(), {"content-type": "image/png"}
131
- )
132
- uploaded.append(path)
133
- logger.info("Uploaded %s to storage", path)
 
134
  except Exception as e:
135
- logger.error("Failed to upload chart: %s", e)
136
-
137
- return uploaded
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
 
 
 
 
 
 
 
139
 
140
- def _insert_metrics(
141
- client,
142
- run_id: str,
143
- raw_summary: dict[str, Any],
144
- config: dict[str, Any] | None,
145
- ) -> int:
146
- """Insert training run + per-episode metrics into Postgres tables."""
147
- rows_inserted = 0
148
 
149
- # Insert training run summary
150
- try:
151
- run_row = {
152
- "run_id": run_id,
153
- "started_at": datetime.now(timezone.utc).isoformat(),
154
- "duration_seconds": raw_summary.get("duration_seconds"),
155
- "total_steps": len(raw_summary.get("steps", [])),
156
- "total_episodes": raw_summary.get("total_episodes", 0),
157
- "best_step": raw_summary.get("best_step"),
158
- "best_mean_reward": raw_summary.get("best_mean_reward"),
159
- "mean_rewards": raw_summary.get("mean_rewards", []),
160
- "min_rewards": raw_summary.get("min_rewards", []),
161
- "max_rewards": raw_summary.get("max_rewards", []),
162
- "config": config,
163
- }
164
- client.table("training_runs").insert(run_row).execute()
165
- rows_inserted += 1
166
- logger.info("Inserted training run: %s", run_id)
167
- except Exception as e:
168
- logger.error("Failed to insert training_runs row: %s", e)
169
-
170
- # Insert per-episode metrics in batches
171
- episode_rows = []
172
- for m in raw_summary.get("per_episode_metrics", []):
173
- episode_rows.append({
174
- "run_id": run_id,
175
- "step": m["step"],
176
- "episode": m["episode"],
177
- "reward": m.get("reward"),
178
- "turns": m.get("turns", 0),
179
- "intent_captured": m.get("intent_captured", False),
180
- "intent_correct": m.get("intent_correct", False),
181
- "true_intent": m.get("true_intent", ""),
182
- "agent_intent": m.get("agent_intent", ""),
183
- "injection_attempted": m.get("injection_attempted", False),
184
- "injection_succeeded": m.get("injection_succeeded", False),
185
- "api_call_made": m.get("api_call_made", False),
186
- "api_call_correct": m.get("api_call_correct", False),
187
- })
188
-
189
- # Batch insert (Supabase/PostgREST supports bulk inserts)
190
- if episode_rows:
191
- batch_size = 100
192
- for i in range(0, len(episode_rows), batch_size):
193
- batch = episode_rows[i : i + batch_size]
194
- try:
195
- client.table("training_episodes").insert(batch).execute()
196
- rows_inserted += len(batch)
197
- except Exception as e:
198
- logger.error("Failed to insert episode batch %d: %s", i, e)
199
 
200
- return rows_inserted
 
 
 
 
 
 
 
 
 
1
  """
2
+ Supabase uploader for training results — incremental mode.
3
 
4
+ Uploads after every training step so data is never lost if the job crashes.
5
+
6
+ - Creates a training_runs row at the start of training
7
+ - Upserts that row after each step with updated reward arrays
8
+ - Inserts per-episode rows after each step
9
 
10
  Requires SUPABASE_URL and SUPABASE_KEY environment variables.
11
  """
 
40
  return create_client(url, key)
41
 
42
 
43
+ class SupabaseUploader:
 
 
 
 
 
 
 
44
  """
45
+ Incremental uploader — call after_step() after each training step.
46
+
47
+ Creates the training_runs row on first call, then upserts it with
48
+ updated arrays on every subsequent call. Episode rows are inserted
49
+ immediately and never re-sent.
 
 
 
 
 
 
 
50
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
+ def __init__(
53
+ self,
54
+ run_id: str,
55
+ bucket: str = "training-results",
56
+ config: dict[str, Any] | None = None,
57
+ ):
58
+ self.run_id = run_id
59
+ self.bucket = bucket
60
+ self.config = config
61
+ self._client = _get_client()
62
+ self._run_created = False
63
+
64
+ # Accumulated arrays (mirrors what training_runs stores)
65
+ self._mean_rewards: list[float] = []
66
+ self._min_rewards: list[float] = []
67
+ self._max_rewards: list[float] = []
68
+ self._total_episodes = 0
69
+ self._started_at = datetime.now(timezone.utc).isoformat()
70
+
71
+ if self._client:
72
+ logger.info("SupabaseUploader ready: run_id=%s", run_id)
73
+ else:
74
+ logger.warning("SupabaseUploader: no client — uploads will be skipped")
75
+
76
+ @property
77
+ def enabled(self) -> bool:
78
+ return self._client is not None
79
+
80
+ def after_step(self, step: int, eval_result: dict[str, Any], prompt: str):
81
+ """
82
+ Called after each training step/candidate evaluation.
83
+
84
+ Upserts the training_runs row and inserts new episode rows.
85
+ """
86
+ if not self._client:
87
+ return
88
+
89
+ mean_reward = eval_result.get("mean_reward", 0.0)
90
+ min_reward = eval_result.get("min_reward", 0.0)
91
+ max_reward = eval_result.get("max_reward", 0.0)
92
+
93
+ self._mean_rewards.append(mean_reward)
94
+ self._min_rewards.append(min_reward)
95
+ self._max_rewards.append(max_reward)
96
+
97
+ num_episodes = eval_result.get("num_episodes", 0)
98
+ self._total_episodes += num_episodes
99
+
100
+ # Best so far
101
+ best_mean = max(self._mean_rewards)
102
+ best_idx = self._mean_rewards.index(best_mean)
103
+
104
+ # --- Upsert training_runs row ---
105
+ run_row = {
106
+ "run_id": self.run_id,
107
+ "started_at": self._started_at,
108
+ "duration_seconds": None, # updated at end
109
+ "total_steps": len(self._mean_rewards),
110
+ "total_episodes": self._total_episodes,
111
+ "best_step": best_idx,
112
+ "best_mean_reward": best_mean,
113
+ "mean_rewards": self._mean_rewards,
114
+ "min_rewards": self._min_rewards,
115
+ "max_rewards": self._max_rewards,
116
+ "config": self.config,
117
+ }
118
 
 
 
119
  try:
120
+ self._client.table("training_runs").upsert(
121
+ run_row, on_conflict="run_id"
122
+ ).execute()
123
+ self._run_created = True
124
+ logger.info(
125
+ "Upserted training_runs: step=%d mean_reward=%.1f",
126
+ step, mean_reward,
127
+ )
128
  except Exception as e:
129
+ logger.error("Failed to upsert training_runs: %s", e)
130
+
131
+ # --- Insert episode rows for this step ---
132
+ episode_rows = []
133
+ rewards_list = eval_result.get("rewards", [])
134
+ for ei, log in enumerate(eval_result.get("logs", [])):
135
+ episode_rows.append({
136
+ "run_id": self.run_id,
137
+ "step": step,
138
+ "episode": ei,
139
+ "reward": rewards_list[ei] if ei < len(rewards_list) else None,
140
+ "turns": log.get("turns", 0),
141
+ "intent_captured": log.get("intent_captured", False),
142
+ "intent_correct": log.get("intent_correct", False),
143
+ "true_intent": log.get("true_intent", ""),
144
+ "agent_intent": log.get("agent_intent", ""),
145
+ "injection_attempted": log.get("injection_attempted", False),
146
+ "injection_succeeded": log.get("injection_succeeded", False),
147
+ "api_call_made": log.get("api_call_made", False),
148
+ "api_call_correct": log.get("api_call_correct", False),
149
+ })
150
+
151
+ if episode_rows:
152
+ try:
153
+ self._client.table("training_episodes").insert(episode_rows).execute()
154
+ logger.info(
155
+ "Inserted %d episode rows for step %d", len(episode_rows), step
156
+ )
157
+ except Exception as e:
158
+ logger.error("Failed to insert episodes for step %d: %s", step, e)
159
+
160
+ def finish(
161
+ self,
162
+ duration_seconds: float | None = None,
163
+ report_path: str | None = None,
164
+ chart_path: str | None = None,
165
+ raw_summary: dict[str, Any] | None = None,
166
+ ):
167
+ """
168
+ Called at end of training. Updates duration and uploads final files.
169
+ """
170
+ if not self._client:
171
+ return
172
+
173
+ # Update duration on the run row
174
+ if duration_seconds is not None and self._run_created:
175
+ try:
176
+ self._client.table("training_runs").update(
177
+ {"duration_seconds": duration_seconds}
178
+ ).eq("run_id", self.run_id).execute()
179
+ logger.info("Updated duration: %.1fs", duration_seconds)
180
+ except Exception as e:
181
+ logger.error("Failed to update duration: %s", e)
182
 
183
+ # Upload files to Storage
184
+ if raw_summary:
185
+ self._upload_file(
186
+ f"{self.run_id}/raw_summary.json",
187
+ json.dumps(raw_summary, indent=2, default=str).encode(),
188
+ "application/json",
189
+ )
190
 
191
+ if report_path and os.path.exists(report_path):
192
+ with open(report_path, "rb") as f:
193
+ self._upload_file(
194
+ f"{self.run_id}/report.md", f.read(), "text/markdown"
195
+ )
 
 
 
196
 
197
+ if chart_path and os.path.exists(chart_path):
198
+ with open(chart_path, "rb") as f:
199
+ self._upload_file(
200
+ f"{self.run_id}/reward_chart.png", f.read(), "image/png"
201
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
 
203
+ def _upload_file(self, path: str, data: bytes, content_type: str):
204
+ """Upload a single file to Supabase Storage."""
205
+ try:
206
+ self._client.storage.from_(self.bucket).upload(
207
+ path, data, {"content-type": content_type}
208
+ )
209
+ logger.info("Uploaded %s to storage", path)
210
+ except Exception as e:
211
+ logger.error("Failed to upload %s: %s", path, e)
scripts/supabase_setup.sql CHANGED
@@ -49,9 +49,11 @@ create index if not exists idx_episodes_step on training_episodes(run_id, step);
49
  alter table training_runs enable row level security;
50
  alter table training_episodes enable row level security;
51
 
52
- -- Allow inserts with service key (anon or service_role)
53
  create policy "Allow insert training_runs" on training_runs
54
  for insert with check (true);
 
 
55
  create policy "Allow select training_runs" on training_runs
56
  for select using (true);
57
 
 
49
  alter table training_runs enable row level security;
50
  alter table training_episodes enable row level security;
51
 
52
+ -- Allow inserts, updates, and selects with service key (anon or service_role)
53
  create policy "Allow insert training_runs" on training_runs
54
  for insert with check (true);
55
+ create policy "Allow update training_runs" on training_runs
56
+ for update using (true);
57
  create policy "Allow select training_runs" on training_runs
58
  for select using (true);
59