Spaces:
Running on T4
Running on T4
Claude commited on
Make Supabase uploads incremental — upload after every step
Browse filesInstead of uploading once at the end (risking total data loss on
crash/timeout), the uploader now:
- Creates/upserts the training_runs row after each step
- Inserts episode rows immediately after each step
- Calls finish() at end to update duration and upload files
Also adds a callback hook to TrainingLogger so the uploader is
notified automatically via log_iteration().
Note: requires adding an UPDATE RLS policy on training_runs
(see updated supabase_setup.sql).
https://claude.ai/code/session_01DPirJ78YYN4fJUvUFJ5D6V
- layer1/train.py +28 -18
- layer1/training_logger.py +12 -0
- layer1/upload.py +165 -154
- scripts/supabase_setup.sql +3 -1
layer1/train.py
CHANGED
|
@@ -33,7 +33,7 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
| 33 |
from config_loader import load_config, make_grpo_config, make_env_config, get_report_config, get_paths, get_generation_config, get_personas_config, get_upload_config
|
| 34 |
from layer1.grpo_trainer import GRPOConfig, GRPOPromptTrainer, PromptEvaluator
|
| 35 |
from layer1.training_logger import TrainingLogger, ReportGenerator
|
| 36 |
-
from layer1.upload import
|
| 37 |
from layer2.customer_sim import CustomerPersona, CustomerSimulator
|
| 38 |
from layer2.hf_agent import HFAgent
|
| 39 |
from personas.generate_personas import generate_personas
|
|
@@ -153,6 +153,24 @@ def run_train(config: GRPOConfig, report_cfg: dict, paths_cfg: dict, hf_token: s
|
|
| 153 |
training_logger = TrainingLogger(
|
| 154 |
log_dir=paths_cfg["log_dir"], total_steps=config.num_training_steps
|
| 155 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
trainer = GRPOPromptTrainer(config=config, evaluator=evaluator, logger=training_logger)
|
| 157 |
trainer.setup_model()
|
| 158 |
trainer.train()
|
|
@@ -214,28 +232,20 @@ def run_train(config: GRPOConfig, report_cfg: dict, paths_cfg: dict, hf_token: s
|
|
| 214 |
except OSError:
|
| 215 |
print("WARNING: Could not re-read report from disk")
|
| 216 |
|
| 217 |
-
#
|
| 218 |
-
|
| 219 |
-
if upload_cfg.get("enabled") and os.environ.get("SUPABASE_URL"):
|
| 220 |
print(f"\n{'='*60}")
|
| 221 |
-
print("
|
| 222 |
print(f"{'='*60}")
|
| 223 |
-
|
|
|
|
|
|
|
| 224 |
raw_summary=raw_summary,
|
| 225 |
-
run_id=training_logger.timestamp,
|
| 226 |
-
bucket=upload_cfg.get("bucket", "training-results"),
|
| 227 |
-
report_path=report_path if report_cfg["enabled"] else None,
|
| 228 |
-
chart_path=None, # chart path is internal to ReportGenerator
|
| 229 |
-
config={"grpo": config.__dict__, "report": report_cfg, "paths": paths_cfg},
|
| 230 |
)
|
| 231 |
-
print(f" Run ID: {
|
| 232 |
-
print(f"
|
| 233 |
-
print(f"
|
| 234 |
-
if upload_result.get("error"):
|
| 235 |
-
print(f" Error: {upload_result['error']}")
|
| 236 |
print(f"{'='*60}")
|
| 237 |
-
elif upload_cfg.get("enabled"):
|
| 238 |
-
print("\nSupabase upload enabled but SUPABASE_URL not set — skipping")
|
| 239 |
|
| 240 |
|
| 241 |
def run_eval(hf_token: str | None, prompt: str, episodes: int):
|
|
|
|
| 33 |
from config_loader import load_config, make_grpo_config, make_env_config, get_report_config, get_paths, get_generation_config, get_personas_config, get_upload_config
|
| 34 |
from layer1.grpo_trainer import GRPOConfig, GRPOPromptTrainer, PromptEvaluator
|
| 35 |
from layer1.training_logger import TrainingLogger, ReportGenerator
|
| 36 |
+
from layer1.upload import SupabaseUploader
|
| 37 |
from layer2.customer_sim import CustomerPersona, CustomerSimulator
|
| 38 |
from layer2.hf_agent import HFAgent
|
| 39 |
from personas.generate_personas import generate_personas
|
|
|
|
| 153 |
training_logger = TrainingLogger(
|
| 154 |
log_dir=paths_cfg["log_dir"], total_steps=config.num_training_steps
|
| 155 |
)
|
| 156 |
+
|
| 157 |
+
# Wire up incremental Supabase uploads
|
| 158 |
+
upload_cfg = upload_cfg or {}
|
| 159 |
+
uploader = None
|
| 160 |
+
if upload_cfg.get("enabled") and os.environ.get("SUPABASE_URL"):
|
| 161 |
+
uploader = SupabaseUploader(
|
| 162 |
+
run_id=training_logger.timestamp,
|
| 163 |
+
bucket=upload_cfg.get("bucket", "training-results"),
|
| 164 |
+
config={"grpo": config.__dict__, "report": report_cfg, "paths": paths_cfg},
|
| 165 |
+
)
|
| 166 |
+
if uploader.enabled:
|
| 167 |
+
training_logger.add_on_step_callback(uploader.after_step)
|
| 168 |
+
print("Supabase incremental upload enabled")
|
| 169 |
+
else:
|
| 170 |
+
uploader = None
|
| 171 |
+
elif upload_cfg.get("enabled"):
|
| 172 |
+
print("Supabase upload enabled but SUPABASE_URL not set — skipping")
|
| 173 |
+
|
| 174 |
trainer = GRPOPromptTrainer(config=config, evaluator=evaluator, logger=training_logger)
|
| 175 |
trainer.setup_model()
|
| 176 |
trainer.train()
|
|
|
|
| 232 |
except OSError:
|
| 233 |
print("WARNING: Could not re-read report from disk")
|
| 234 |
|
| 235 |
+
# Finalize Supabase upload (update duration, upload files)
|
| 236 |
+
if uploader and uploader.enabled:
|
|
|
|
| 237 |
print(f"\n{'='*60}")
|
| 238 |
+
print("FINALIZING SUPABASE UPLOAD...")
|
| 239 |
print(f"{'='*60}")
|
| 240 |
+
uploader.finish(
|
| 241 |
+
duration_seconds=raw_summary.get("duration_seconds"),
|
| 242 |
+
report_path=report_path,
|
| 243 |
raw_summary=raw_summary,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
)
|
| 245 |
+
print(f" Run ID: {uploader.run_id}")
|
| 246 |
+
print(f" Steps uploaded incrementally: {len(uploader._mean_rewards)}")
|
| 247 |
+
print(f" Episodes uploaded: {uploader._total_episodes}")
|
|
|
|
|
|
|
| 248 |
print(f"{'='*60}")
|
|
|
|
|
|
|
| 249 |
|
| 250 |
|
| 251 |
def run_eval(hf_token: str | None, prompt: str, episodes: int):
|
layer1/training_logger.py
CHANGED
|
@@ -32,6 +32,7 @@ class TrainingLogger:
|
|
| 32 |
self.total_steps = total_steps
|
| 33 |
self.iterations: list[dict[str, Any]] = []
|
| 34 |
self._start_time = datetime.now()
|
|
|
|
| 35 |
|
| 36 |
with open(self.log_path, "w") as f:
|
| 37 |
f.write(f"Training Log — {self._start_time.isoformat()}\n")
|
|
@@ -39,6 +40,10 @@ class TrainingLogger:
|
|
| 39 |
f.flush()
|
| 40 |
os.fsync(f.fileno())
|
| 41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
def log_iteration(self, step: int, prompt: str, eval_result: dict[str, Any]):
|
| 43 |
"""Log a single training iteration (one prompt evaluated)."""
|
| 44 |
entry = {
|
|
@@ -69,6 +74,13 @@ class TrainingLogger:
|
|
| 69 |
|
| 70 |
logger.info("Logged step %d: mean_reward=%.1f", step, entry["mean_reward"])
|
| 71 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
def save_json(self):
|
| 73 |
"""Save structured training data to JSON."""
|
| 74 |
data = {
|
|
|
|
| 32 |
self.total_steps = total_steps
|
| 33 |
self.iterations: list[dict[str, Any]] = []
|
| 34 |
self._start_time = datetime.now()
|
| 35 |
+
self._on_step_callbacks: list[Any] = []
|
| 36 |
|
| 37 |
with open(self.log_path, "w") as f:
|
| 38 |
f.write(f"Training Log — {self._start_time.isoformat()}\n")
|
|
|
|
| 40 |
f.flush()
|
| 41 |
os.fsync(f.fileno())
|
| 42 |
|
| 43 |
+
def add_on_step_callback(self, callback):
|
| 44 |
+
"""Register a callback called after each step: callback(step, eval_result, prompt)."""
|
| 45 |
+
self._on_step_callbacks.append(callback)
|
| 46 |
+
|
| 47 |
def log_iteration(self, step: int, prompt: str, eval_result: dict[str, Any]):
|
| 48 |
"""Log a single training iteration (one prompt evaluated)."""
|
| 49 |
entry = {
|
|
|
|
| 74 |
|
| 75 |
logger.info("Logged step %d: mean_reward=%.1f", step, entry["mean_reward"])
|
| 76 |
|
| 77 |
+
# Notify callbacks (e.g. Supabase uploader)
|
| 78 |
+
for cb in self._on_step_callbacks:
|
| 79 |
+
try:
|
| 80 |
+
cb(step, eval_result, prompt)
|
| 81 |
+
except Exception as e:
|
| 82 |
+
logger.error("Step callback failed: %s", e)
|
| 83 |
+
|
| 84 |
def save_json(self):
|
| 85 |
"""Save structured training data to JSON."""
|
| 86 |
data = {
|
layer1/upload.py
CHANGED
|
@@ -1,9 +1,11 @@
|
|
| 1 |
"""
|
| 2 |
-
Supabase uploader for training results.
|
| 3 |
|
| 4 |
-
Uploads
|
| 5 |
-
|
| 6 |
-
|
|
|
|
|
|
|
| 7 |
|
| 8 |
Requires SUPABASE_URL and SUPABASE_KEY environment variables.
|
| 9 |
"""
|
|
@@ -38,163 +40,172 @@ def _get_client():
|
|
| 38 |
return create_client(url, key)
|
| 39 |
|
| 40 |
|
| 41 |
-
|
| 42 |
-
raw_summary: dict[str, Any],
|
| 43 |
-
run_id: str | None = None,
|
| 44 |
-
bucket: str = "training-results",
|
| 45 |
-
report_path: str | None = None,
|
| 46 |
-
chart_path: str | None = None,
|
| 47 |
-
config: dict[str, Any] | None = None,
|
| 48 |
-
) -> dict[str, Any]:
|
| 49 |
"""
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
bucket: Supabase Storage bucket name.
|
| 56 |
-
report_path: Path to the markdown report file (optional).
|
| 57 |
-
chart_path: Path to the reward chart PNG (optional).
|
| 58 |
-
config: Training config dict to store with the run (optional).
|
| 59 |
-
|
| 60 |
-
Returns:
|
| 61 |
-
Dict with upload results: {"run_id", "storage_paths", "db_rows"}.
|
| 62 |
"""
|
| 63 |
-
client = _get_client()
|
| 64 |
-
if client is None:
|
| 65 |
-
logger.warning("Supabase upload skipped — client not available")
|
| 66 |
-
return {"run_id": None, "storage_paths": [], "db_rows": 0, "error": "no client"}
|
| 67 |
-
|
| 68 |
-
if run_id is None:
|
| 69 |
-
run_id = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
|
| 70 |
-
|
| 71 |
-
results: dict[str, Any] = {"run_id": run_id, "storage_paths": [], "db_rows": 0}
|
| 72 |
-
|
| 73 |
-
# --- Storage uploads ---
|
| 74 |
-
results["storage_paths"] = _upload_files(
|
| 75 |
-
client, bucket, run_id, raw_summary, report_path, chart_path
|
| 76 |
-
)
|
| 77 |
-
|
| 78 |
-
# --- DB inserts ---
|
| 79 |
-
results["db_rows"] = _insert_metrics(client, run_id, raw_summary, config)
|
| 80 |
-
|
| 81 |
-
logger.info(
|
| 82 |
-
"Supabase upload complete: run_id=%s, files=%d, db_rows=%d",
|
| 83 |
-
run_id, len(results["storage_paths"]), results["db_rows"],
|
| 84 |
-
)
|
| 85 |
-
return results
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
def _upload_files(
|
| 89 |
-
client,
|
| 90 |
-
bucket: str,
|
| 91 |
-
run_id: str,
|
| 92 |
-
raw_summary: dict[str, Any],
|
| 93 |
-
report_path: str | None,
|
| 94 |
-
chart_path: str | None,
|
| 95 |
-
) -> list[str]:
|
| 96 |
-
"""Upload files to Supabase Storage."""
|
| 97 |
-
uploaded = []
|
| 98 |
-
|
| 99 |
-
# Upload raw summary JSON
|
| 100 |
-
try:
|
| 101 |
-
summary_bytes = json.dumps(raw_summary, indent=2, default=str).encode()
|
| 102 |
-
path = f"{run_id}/raw_summary.json"
|
| 103 |
-
client.storage.from_(bucket).upload(
|
| 104 |
-
path, summary_bytes, {"content-type": "application/json"}
|
| 105 |
-
)
|
| 106 |
-
uploaded.append(path)
|
| 107 |
-
logger.info("Uploaded %s to storage", path)
|
| 108 |
-
except Exception as e:
|
| 109 |
-
logger.error("Failed to upload raw_summary.json: %s", e)
|
| 110 |
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
|
| 124 |
-
# Upload chart PNG
|
| 125 |
-
if chart_path and os.path.exists(chart_path):
|
| 126 |
try:
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
|
|
|
| 134 |
except Exception as e:
|
| 135 |
-
logger.error("Failed to
|
| 136 |
-
|
| 137 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
) -> int:
|
| 146 |
-
"""Insert training run + per-episode metrics into Postgres tables."""
|
| 147 |
-
rows_inserted = 0
|
| 148 |
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
"duration_seconds": raw_summary.get("duration_seconds"),
|
| 155 |
-
"total_steps": len(raw_summary.get("steps", [])),
|
| 156 |
-
"total_episodes": raw_summary.get("total_episodes", 0),
|
| 157 |
-
"best_step": raw_summary.get("best_step"),
|
| 158 |
-
"best_mean_reward": raw_summary.get("best_mean_reward"),
|
| 159 |
-
"mean_rewards": raw_summary.get("mean_rewards", []),
|
| 160 |
-
"min_rewards": raw_summary.get("min_rewards", []),
|
| 161 |
-
"max_rewards": raw_summary.get("max_rewards", []),
|
| 162 |
-
"config": config,
|
| 163 |
-
}
|
| 164 |
-
client.table("training_runs").insert(run_row).execute()
|
| 165 |
-
rows_inserted += 1
|
| 166 |
-
logger.info("Inserted training run: %s", run_id)
|
| 167 |
-
except Exception as e:
|
| 168 |
-
logger.error("Failed to insert training_runs row: %s", e)
|
| 169 |
-
|
| 170 |
-
# Insert per-episode metrics in batches
|
| 171 |
-
episode_rows = []
|
| 172 |
-
for m in raw_summary.get("per_episode_metrics", []):
|
| 173 |
-
episode_rows.append({
|
| 174 |
-
"run_id": run_id,
|
| 175 |
-
"step": m["step"],
|
| 176 |
-
"episode": m["episode"],
|
| 177 |
-
"reward": m.get("reward"),
|
| 178 |
-
"turns": m.get("turns", 0),
|
| 179 |
-
"intent_captured": m.get("intent_captured", False),
|
| 180 |
-
"intent_correct": m.get("intent_correct", False),
|
| 181 |
-
"true_intent": m.get("true_intent", ""),
|
| 182 |
-
"agent_intent": m.get("agent_intent", ""),
|
| 183 |
-
"injection_attempted": m.get("injection_attempted", False),
|
| 184 |
-
"injection_succeeded": m.get("injection_succeeded", False),
|
| 185 |
-
"api_call_made": m.get("api_call_made", False),
|
| 186 |
-
"api_call_correct": m.get("api_call_correct", False),
|
| 187 |
-
})
|
| 188 |
-
|
| 189 |
-
# Batch insert (Supabase/PostgREST supports bulk inserts)
|
| 190 |
-
if episode_rows:
|
| 191 |
-
batch_size = 100
|
| 192 |
-
for i in range(0, len(episode_rows), batch_size):
|
| 193 |
-
batch = episode_rows[i : i + batch_size]
|
| 194 |
-
try:
|
| 195 |
-
client.table("training_episodes").insert(batch).execute()
|
| 196 |
-
rows_inserted += len(batch)
|
| 197 |
-
except Exception as e:
|
| 198 |
-
logger.error("Failed to insert episode batch %d: %s", i, e)
|
| 199 |
|
| 200 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""
|
| 2 |
+
Supabase uploader for training results — incremental mode.
|
| 3 |
|
| 4 |
+
Uploads after every training step so data is never lost if the job crashes.
|
| 5 |
+
|
| 6 |
+
- Creates a training_runs row at the start of training
|
| 7 |
+
- Upserts that row after each step with updated reward arrays
|
| 8 |
+
- Inserts per-episode rows after each step
|
| 9 |
|
| 10 |
Requires SUPABASE_URL and SUPABASE_KEY environment variables.
|
| 11 |
"""
|
|
|
|
| 40 |
return create_client(url, key)
|
| 41 |
|
| 42 |
|
| 43 |
+
class SupabaseUploader:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
"""
|
| 45 |
+
Incremental uploader — call after_step() after each training step.
|
| 46 |
+
|
| 47 |
+
Creates the training_runs row on first call, then upserts it with
|
| 48 |
+
updated arrays on every subsequent call. Episode rows are inserted
|
| 49 |
+
immediately and never re-sent.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
+
def __init__(
|
| 53 |
+
self,
|
| 54 |
+
run_id: str,
|
| 55 |
+
bucket: str = "training-results",
|
| 56 |
+
config: dict[str, Any] | None = None,
|
| 57 |
+
):
|
| 58 |
+
self.run_id = run_id
|
| 59 |
+
self.bucket = bucket
|
| 60 |
+
self.config = config
|
| 61 |
+
self._client = _get_client()
|
| 62 |
+
self._run_created = False
|
| 63 |
+
|
| 64 |
+
# Accumulated arrays (mirrors what training_runs stores)
|
| 65 |
+
self._mean_rewards: list[float] = []
|
| 66 |
+
self._min_rewards: list[float] = []
|
| 67 |
+
self._max_rewards: list[float] = []
|
| 68 |
+
self._total_episodes = 0
|
| 69 |
+
self._started_at = datetime.now(timezone.utc).isoformat()
|
| 70 |
+
|
| 71 |
+
if self._client:
|
| 72 |
+
logger.info("SupabaseUploader ready: run_id=%s", run_id)
|
| 73 |
+
else:
|
| 74 |
+
logger.warning("SupabaseUploader: no client — uploads will be skipped")
|
| 75 |
+
|
| 76 |
+
@property
|
| 77 |
+
def enabled(self) -> bool:
|
| 78 |
+
return self._client is not None
|
| 79 |
+
|
| 80 |
+
def after_step(self, step: int, eval_result: dict[str, Any], prompt: str):
|
| 81 |
+
"""
|
| 82 |
+
Called after each training step/candidate evaluation.
|
| 83 |
+
|
| 84 |
+
Upserts the training_runs row and inserts new episode rows.
|
| 85 |
+
"""
|
| 86 |
+
if not self._client:
|
| 87 |
+
return
|
| 88 |
+
|
| 89 |
+
mean_reward = eval_result.get("mean_reward", 0.0)
|
| 90 |
+
min_reward = eval_result.get("min_reward", 0.0)
|
| 91 |
+
max_reward = eval_result.get("max_reward", 0.0)
|
| 92 |
+
|
| 93 |
+
self._mean_rewards.append(mean_reward)
|
| 94 |
+
self._min_rewards.append(min_reward)
|
| 95 |
+
self._max_rewards.append(max_reward)
|
| 96 |
+
|
| 97 |
+
num_episodes = eval_result.get("num_episodes", 0)
|
| 98 |
+
self._total_episodes += num_episodes
|
| 99 |
+
|
| 100 |
+
# Best so far
|
| 101 |
+
best_mean = max(self._mean_rewards)
|
| 102 |
+
best_idx = self._mean_rewards.index(best_mean)
|
| 103 |
+
|
| 104 |
+
# --- Upsert training_runs row ---
|
| 105 |
+
run_row = {
|
| 106 |
+
"run_id": self.run_id,
|
| 107 |
+
"started_at": self._started_at,
|
| 108 |
+
"duration_seconds": None, # updated at end
|
| 109 |
+
"total_steps": len(self._mean_rewards),
|
| 110 |
+
"total_episodes": self._total_episodes,
|
| 111 |
+
"best_step": best_idx,
|
| 112 |
+
"best_mean_reward": best_mean,
|
| 113 |
+
"mean_rewards": self._mean_rewards,
|
| 114 |
+
"min_rewards": self._min_rewards,
|
| 115 |
+
"max_rewards": self._max_rewards,
|
| 116 |
+
"config": self.config,
|
| 117 |
+
}
|
| 118 |
|
|
|
|
|
|
|
| 119 |
try:
|
| 120 |
+
self._client.table("training_runs").upsert(
|
| 121 |
+
run_row, on_conflict="run_id"
|
| 122 |
+
).execute()
|
| 123 |
+
self._run_created = True
|
| 124 |
+
logger.info(
|
| 125 |
+
"Upserted training_runs: step=%d mean_reward=%.1f",
|
| 126 |
+
step, mean_reward,
|
| 127 |
+
)
|
| 128 |
except Exception as e:
|
| 129 |
+
logger.error("Failed to upsert training_runs: %s", e)
|
| 130 |
+
|
| 131 |
+
# --- Insert episode rows for this step ---
|
| 132 |
+
episode_rows = []
|
| 133 |
+
rewards_list = eval_result.get("rewards", [])
|
| 134 |
+
for ei, log in enumerate(eval_result.get("logs", [])):
|
| 135 |
+
episode_rows.append({
|
| 136 |
+
"run_id": self.run_id,
|
| 137 |
+
"step": step,
|
| 138 |
+
"episode": ei,
|
| 139 |
+
"reward": rewards_list[ei] if ei < len(rewards_list) else None,
|
| 140 |
+
"turns": log.get("turns", 0),
|
| 141 |
+
"intent_captured": log.get("intent_captured", False),
|
| 142 |
+
"intent_correct": log.get("intent_correct", False),
|
| 143 |
+
"true_intent": log.get("true_intent", ""),
|
| 144 |
+
"agent_intent": log.get("agent_intent", ""),
|
| 145 |
+
"injection_attempted": log.get("injection_attempted", False),
|
| 146 |
+
"injection_succeeded": log.get("injection_succeeded", False),
|
| 147 |
+
"api_call_made": log.get("api_call_made", False),
|
| 148 |
+
"api_call_correct": log.get("api_call_correct", False),
|
| 149 |
+
})
|
| 150 |
+
|
| 151 |
+
if episode_rows:
|
| 152 |
+
try:
|
| 153 |
+
self._client.table("training_episodes").insert(episode_rows).execute()
|
| 154 |
+
logger.info(
|
| 155 |
+
"Inserted %d episode rows for step %d", len(episode_rows), step
|
| 156 |
+
)
|
| 157 |
+
except Exception as e:
|
| 158 |
+
logger.error("Failed to insert episodes for step %d: %s", step, e)
|
| 159 |
+
|
| 160 |
+
def finish(
|
| 161 |
+
self,
|
| 162 |
+
duration_seconds: float | None = None,
|
| 163 |
+
report_path: str | None = None,
|
| 164 |
+
chart_path: str | None = None,
|
| 165 |
+
raw_summary: dict[str, Any] | None = None,
|
| 166 |
+
):
|
| 167 |
+
"""
|
| 168 |
+
Called at end of training. Updates duration and uploads final files.
|
| 169 |
+
"""
|
| 170 |
+
if not self._client:
|
| 171 |
+
return
|
| 172 |
+
|
| 173 |
+
# Update duration on the run row
|
| 174 |
+
if duration_seconds is not None and self._run_created:
|
| 175 |
+
try:
|
| 176 |
+
self._client.table("training_runs").update(
|
| 177 |
+
{"duration_seconds": duration_seconds}
|
| 178 |
+
).eq("run_id", self.run_id).execute()
|
| 179 |
+
logger.info("Updated duration: %.1fs", duration_seconds)
|
| 180 |
+
except Exception as e:
|
| 181 |
+
logger.error("Failed to update duration: %s", e)
|
| 182 |
|
| 183 |
+
# Upload files to Storage
|
| 184 |
+
if raw_summary:
|
| 185 |
+
self._upload_file(
|
| 186 |
+
f"{self.run_id}/raw_summary.json",
|
| 187 |
+
json.dumps(raw_summary, indent=2, default=str).encode(),
|
| 188 |
+
"application/json",
|
| 189 |
+
)
|
| 190 |
|
| 191 |
+
if report_path and os.path.exists(report_path):
|
| 192 |
+
with open(report_path, "rb") as f:
|
| 193 |
+
self._upload_file(
|
| 194 |
+
f"{self.run_id}/report.md", f.read(), "text/markdown"
|
| 195 |
+
)
|
|
|
|
|
|
|
|
|
|
| 196 |
|
| 197 |
+
if chart_path and os.path.exists(chart_path):
|
| 198 |
+
with open(chart_path, "rb") as f:
|
| 199 |
+
self._upload_file(
|
| 200 |
+
f"{self.run_id}/reward_chart.png", f.read(), "image/png"
|
| 201 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
|
| 203 |
+
def _upload_file(self, path: str, data: bytes, content_type: str):
|
| 204 |
+
"""Upload a single file to Supabase Storage."""
|
| 205 |
+
try:
|
| 206 |
+
self._client.storage.from_(self.bucket).upload(
|
| 207 |
+
path, data, {"content-type": content_type}
|
| 208 |
+
)
|
| 209 |
+
logger.info("Uploaded %s to storage", path)
|
| 210 |
+
except Exception as e:
|
| 211 |
+
logger.error("Failed to upload %s: %s", path, e)
|
scripts/supabase_setup.sql
CHANGED
|
@@ -49,9 +49,11 @@ create index if not exists idx_episodes_step on training_episodes(run_id, step);
|
|
| 49 |
alter table training_runs enable row level security;
|
| 50 |
alter table training_episodes enable row level security;
|
| 51 |
|
| 52 |
-
-- Allow inserts with service key (anon or service_role)
|
| 53 |
create policy "Allow insert training_runs" on training_runs
|
| 54 |
for insert with check (true);
|
|
|
|
|
|
|
| 55 |
create policy "Allow select training_runs" on training_runs
|
| 56 |
for select using (true);
|
| 57 |
|
|
|
|
| 49 |
alter table training_runs enable row level security;
|
| 50 |
alter table training_episodes enable row level security;
|
| 51 |
|
| 52 |
+
-- Allow inserts, updates, and selects with service key (anon or service_role)
|
| 53 |
create policy "Allow insert training_runs" on training_runs
|
| 54 |
for insert with check (true);
|
| 55 |
+
create policy "Allow update training_runs" on training_runs
|
| 56 |
+
for update using (true);
|
| 57 |
create policy "Allow select training_runs" on training_runs
|
| 58 |
for select using (true);
|
| 59 |
|