Spaces:
Running on T4
Running on T4
Claude committed on
Add Supabase upload for training results (Storage + DB)
Browse files- layer1/upload.py: uploads raw summary JSON, report, and chart to
Supabase Storage; inserts per-run and per-episode metrics into
training_runs and training_episodes Postgres tables
- scripts/supabase_setup.sql: migration to create tables, indexes,
and RLS policies — run in Supabase SQL Editor before first training
- config.yaml: upload section with enabled flag and bucket name
- config_loader.py: get_upload_config() for new section
- pyproject.toml: supabase>=2.0.0 as optional [upload] dependency
Requires SUPABASE_URL and SUPABASE_KEY env vars. Gracefully skips
upload if not configured (logs warning, training still completes).
https://claude.ai/code/session_01DPirJ78YYN4fJUvUFJ5D6V
- config.yaml +9 -0
- config_loader.py +9 -0
- layer1/train.py +29 -3
- layer1/upload.py +200 -0
- pyproject.toml +3 -0
- scripts/supabase_setup.sql +61 -0
config.yaml
CHANGED
|
@@ -97,6 +97,15 @@ report:
|
|
| 97 |
example_customers: 5 # Example conversations in report
|
| 98 |
|
| 99 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
# --- Paths ---
|
| 101 |
|
| 102 |
paths:
|
|
|
|
| 97 |
example_customers: 5 # Example conversations in report
|
| 98 |
|
| 99 |
|
| 100 |
+
# --- Upload: Supabase ---
|
| 101 |
+
# Upload training results to Supabase for analysis.
|
| 102 |
+
# Requires SUPABASE_URL and SUPABASE_KEY environment variables.
|
| 103 |
+
|
| 104 |
+
upload:
|
| 105 |
+
enabled: true
|
| 106 |
+
bucket: "training-results" # Supabase Storage bucket name
|
| 107 |
+
|
| 108 |
+
|
| 109 |
# --- Paths ---
|
| 110 |
|
| 111 |
paths:
|
config_loader.py
CHANGED
|
@@ -126,6 +126,15 @@ def get_generation_config(cfg: dict[str, Any]) -> dict[str, Any]:
|
|
| 126 |
}
|
| 127 |
|
| 128 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
def get_personas_config(cfg: dict[str, Any]) -> dict[str, Any]:
|
| 130 |
"""Extract persona settings from config."""
|
| 131 |
personas = cfg.get("personas", {})
|
|
|
|
| 126 |
}
|
| 127 |
|
| 128 |
|
| 129 |
+
def get_upload_config(cfg: dict[str, Any]) -> dict[str, Any]:
|
| 130 |
+
"""Extract Supabase upload settings from config."""
|
| 131 |
+
upload = cfg.get("upload", {})
|
| 132 |
+
return {
|
| 133 |
+
"enabled": upload.get("enabled", False),
|
| 134 |
+
"bucket": upload.get("bucket", "training-results"),
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
|
| 138 |
def get_personas_config(cfg: dict[str, Any]) -> dict[str, Any]:
|
| 139 |
"""Extract persona settings from config."""
|
| 140 |
personas = cfg.get("personas", {})
|
layer1/train.py
CHANGED
|
@@ -30,9 +30,10 @@ load_dotenv(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file_
|
|
| 30 |
|
| 31 |
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 32 |
|
| 33 |
-
from config_loader import load_config, make_grpo_config, make_env_config, get_report_config, get_paths, get_generation_config, get_personas_config
|
| 34 |
from layer1.grpo_trainer import GRPOConfig, GRPOPromptTrainer, PromptEvaluator
|
| 35 |
from layer1.training_logger import TrainingLogger, ReportGenerator
|
|
|
|
| 36 |
from layer2.customer_sim import CustomerPersona, CustomerSimulator
|
| 37 |
from layer2.hf_agent import HFAgent
|
| 38 |
from personas.generate_personas import generate_personas
|
|
@@ -138,7 +139,7 @@ def _print_config_banner(config: GRPOConfig, report_cfg: dict, paths_cfg: dict):
|
|
| 138 |
print(f"{'='*70}\n")
|
| 139 |
|
| 140 |
|
| 141 |
-
def run_train(config: GRPOConfig, report_cfg: dict, paths_cfg: dict, hf_token: str | None, gen_cfg: dict | None = None, personas_cfg: dict | None = None):
|
| 142 |
"""Run GRPO training."""
|
| 143 |
_print_config_banner(config, report_cfg, paths_cfg)
|
| 144 |
|
|
@@ -188,6 +189,7 @@ def run_train(config: GRPOConfig, report_cfg: dict, paths_cfg: dict, hf_token: s
|
|
| 188 |
print(f"\nFull raw JSON: {summary_path}")
|
| 189 |
print(f"{'='*60}")
|
| 190 |
|
|
|
|
| 191 |
if report_cfg["enabled"]:
|
| 192 |
print(f"\n{'='*60}")
|
| 193 |
print("GENERATING TRAINING REPORT...")
|
|
@@ -212,6 +214,29 @@ def run_train(config: GRPOConfig, report_cfg: dict, paths_cfg: dict, hf_token: s
|
|
| 212 |
except OSError:
|
| 213 |
print("WARNING: Could not re-read report from disk")
|
| 214 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
|
| 216 |
def run_eval(hf_token: str | None, prompt: str, episodes: int):
|
| 217 |
"""Evaluate a single prompt."""
|
|
@@ -268,6 +293,7 @@ def main():
|
|
| 268 |
paths_cfg = get_paths(cfg)
|
| 269 |
gen_cfg = get_generation_config(cfg)
|
| 270 |
personas_cfg = get_personas_config(cfg)
|
|
|
|
| 271 |
|
| 272 |
# CLI overrides
|
| 273 |
if args.steps is not None:
|
|
@@ -289,7 +315,7 @@ def main():
|
|
| 289 |
report_cfg["example_customers"] = args.example_customers
|
| 290 |
|
| 291 |
if args.mode == "train":
|
| 292 |
-
run_train(grpo_config, report_cfg, paths_cfg, args.hf_token, gen_cfg=gen_cfg, personas_cfg=personas_cfg)
|
| 293 |
elif args.mode == "eval":
|
| 294 |
if not args.prompt:
|
| 295 |
parser.error("--prompt is required for eval mode")
|
|
|
|
| 30 |
|
| 31 |
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 32 |
|
| 33 |
+
from config_loader import load_config, make_grpo_config, make_env_config, get_report_config, get_paths, get_generation_config, get_personas_config, get_upload_config
|
| 34 |
from layer1.grpo_trainer import GRPOConfig, GRPOPromptTrainer, PromptEvaluator
|
| 35 |
from layer1.training_logger import TrainingLogger, ReportGenerator
|
| 36 |
+
from layer1.upload import upload_training_results
|
| 37 |
from layer2.customer_sim import CustomerPersona, CustomerSimulator
|
| 38 |
from layer2.hf_agent import HFAgent
|
| 39 |
from personas.generate_personas import generate_personas
|
|
|
|
| 139 |
print(f"{'='*70}\n")
|
| 140 |
|
| 141 |
|
| 142 |
+
def run_train(config: GRPOConfig, report_cfg: dict, paths_cfg: dict, hf_token: str | None, gen_cfg: dict | None = None, personas_cfg: dict | None = None, upload_cfg: dict | None = None):
|
| 143 |
"""Run GRPO training."""
|
| 144 |
_print_config_banner(config, report_cfg, paths_cfg)
|
| 145 |
|
|
|
|
| 189 |
print(f"\nFull raw JSON: {summary_path}")
|
| 190 |
print(f"{'='*60}")
|
| 191 |
|
| 192 |
+
report_path = None
|
| 193 |
if report_cfg["enabled"]:
|
| 194 |
print(f"\n{'='*60}")
|
| 195 |
print("GENERATING TRAINING REPORT...")
|
|
|
|
| 214 |
except OSError:
|
| 215 |
print("WARNING: Could not re-read report from disk")
|
| 216 |
|
| 217 |
+
# Upload to Supabase if configured
|
| 218 |
+
upload_cfg = upload_cfg or {}
|
| 219 |
+
if upload_cfg.get("enabled") and os.environ.get("SUPABASE_URL"):
|
| 220 |
+
print(f"\n{'='*60}")
|
| 221 |
+
print("UPLOADING TO SUPABASE...")
|
| 222 |
+
print(f"{'='*60}")
|
| 223 |
+
upload_result = upload_training_results(
|
| 224 |
+
raw_summary=raw_summary,
|
| 225 |
+
run_id=training_logger.timestamp,
|
| 226 |
+
bucket=upload_cfg.get("bucket", "training-results"),
|
| 227 |
+
report_path=report_path if report_cfg["enabled"] else None,
|
| 228 |
+
chart_path=None, # chart path is internal to ReportGenerator
|
| 229 |
+
config={"grpo": config.__dict__, "report": report_cfg, "paths": paths_cfg},
|
| 230 |
+
)
|
| 231 |
+
print(f" Run ID: {upload_result['run_id']}")
|
| 232 |
+
print(f" Files: {len(upload_result['storage_paths'])} uploaded")
|
| 233 |
+
print(f" DB rows: {upload_result['db_rows']}")
|
| 234 |
+
if upload_result.get("error"):
|
| 235 |
+
print(f" Error: {upload_result['error']}")
|
| 236 |
+
print(f"{'='*60}")
|
| 237 |
+
elif upload_cfg.get("enabled"):
|
| 238 |
+
print("\nSupabase upload enabled but SUPABASE_URL not set — skipping")
|
| 239 |
+
|
| 240 |
|
| 241 |
def run_eval(hf_token: str | None, prompt: str, episodes: int):
|
| 242 |
"""Evaluate a single prompt."""
|
|
|
|
| 293 |
paths_cfg = get_paths(cfg)
|
| 294 |
gen_cfg = get_generation_config(cfg)
|
| 295 |
personas_cfg = get_personas_config(cfg)
|
| 296 |
+
upload_cfg = get_upload_config(cfg)
|
| 297 |
|
| 298 |
# CLI overrides
|
| 299 |
if args.steps is not None:
|
|
|
|
| 315 |
report_cfg["example_customers"] = args.example_customers
|
| 316 |
|
| 317 |
if args.mode == "train":
|
| 318 |
+
run_train(grpo_config, report_cfg, paths_cfg, args.hf_token, gen_cfg=gen_cfg, personas_cfg=personas_cfg, upload_cfg=upload_cfg)
|
| 319 |
elif args.mode == "eval":
|
| 320 |
if not args.prompt:
|
| 321 |
parser.error("--prompt is required for eval mode")
|
layer1/upload.py
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Supabase uploader for training results.
|
| 3 |
+
|
| 4 |
+
Uploads:
|
| 5 |
+
1. Raw summary JSON + report files to Supabase Storage
|
| 6 |
+
2. Per-run and per-episode metrics to Postgres tables
|
| 7 |
+
|
| 8 |
+
Requires SUPABASE_URL and SUPABASE_KEY environment variables.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import json
|
| 14 |
+
import logging
|
| 15 |
+
import os
|
| 16 |
+
from datetime import datetime, timezone
|
| 17 |
+
from typing import Any
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _get_client():
|
| 23 |
+
"""Create a Supabase client from environment variables."""
|
| 24 |
+
try:
|
| 25 |
+
from supabase import create_client
|
| 26 |
+
except ImportError:
|
| 27 |
+
logger.error(
|
| 28 |
+
"supabase package not installed. Install with: pip install 'nested-rl-envs[upload]'"
|
| 29 |
+
)
|
| 30 |
+
return None
|
| 31 |
+
|
| 32 |
+
url = os.environ.get("SUPABASE_URL")
|
| 33 |
+
key = os.environ.get("SUPABASE_KEY")
|
| 34 |
+
if not url or not key:
|
| 35 |
+
logger.error("SUPABASE_URL and SUPABASE_KEY must be set")
|
| 36 |
+
return None
|
| 37 |
+
|
| 38 |
+
return create_client(url, key)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def upload_training_results(
|
| 42 |
+
raw_summary: dict[str, Any],
|
| 43 |
+
run_id: str | None = None,
|
| 44 |
+
bucket: str = "training-results",
|
| 45 |
+
report_path: str | None = None,
|
| 46 |
+
chart_path: str | None = None,
|
| 47 |
+
config: dict[str, Any] | None = None,
|
| 48 |
+
) -> dict[str, Any]:
|
| 49 |
+
"""
|
| 50 |
+
Upload training results to Supabase (Storage + DB).
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
raw_summary: Output of TrainingLogger.generate_raw_summary().
|
| 54 |
+
run_id: Unique run identifier. Auto-generated if not provided.
|
| 55 |
+
bucket: Supabase Storage bucket name.
|
| 56 |
+
report_path: Path to the markdown report file (optional).
|
| 57 |
+
chart_path: Path to the reward chart PNG (optional).
|
| 58 |
+
config: Training config dict to store with the run (optional).
|
| 59 |
+
|
| 60 |
+
Returns:
|
| 61 |
+
Dict with upload results: {"run_id", "storage_paths", "db_rows"}.
|
| 62 |
+
"""
|
| 63 |
+
client = _get_client()
|
| 64 |
+
if client is None:
|
| 65 |
+
logger.warning("Supabase upload skipped — client not available")
|
| 66 |
+
return {"run_id": None, "storage_paths": [], "db_rows": 0, "error": "no client"}
|
| 67 |
+
|
| 68 |
+
if run_id is None:
|
| 69 |
+
run_id = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
|
| 70 |
+
|
| 71 |
+
results: dict[str, Any] = {"run_id": run_id, "storage_paths": [], "db_rows": 0}
|
| 72 |
+
|
| 73 |
+
# --- Storage uploads ---
|
| 74 |
+
results["storage_paths"] = _upload_files(
|
| 75 |
+
client, bucket, run_id, raw_summary, report_path, chart_path
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
# --- DB inserts ---
|
| 79 |
+
results["db_rows"] = _insert_metrics(client, run_id, raw_summary, config)
|
| 80 |
+
|
| 81 |
+
logger.info(
|
| 82 |
+
"Supabase upload complete: run_id=%s, files=%d, db_rows=%d",
|
| 83 |
+
run_id, len(results["storage_paths"]), results["db_rows"],
|
| 84 |
+
)
|
| 85 |
+
return results
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _upload_files(
|
| 89 |
+
client,
|
| 90 |
+
bucket: str,
|
| 91 |
+
run_id: str,
|
| 92 |
+
raw_summary: dict[str, Any],
|
| 93 |
+
report_path: str | None,
|
| 94 |
+
chart_path: str | None,
|
| 95 |
+
) -> list[str]:
|
| 96 |
+
"""Upload files to Supabase Storage."""
|
| 97 |
+
uploaded = []
|
| 98 |
+
|
| 99 |
+
# Upload raw summary JSON
|
| 100 |
+
try:
|
| 101 |
+
summary_bytes = json.dumps(raw_summary, indent=2, default=str).encode()
|
| 102 |
+
path = f"{run_id}/raw_summary.json"
|
| 103 |
+
client.storage.from_(bucket).upload(
|
| 104 |
+
path, summary_bytes, {"content-type": "application/json"}
|
| 105 |
+
)
|
| 106 |
+
uploaded.append(path)
|
| 107 |
+
logger.info("Uploaded %s to storage", path)
|
| 108 |
+
except Exception as e:
|
| 109 |
+
logger.error("Failed to upload raw_summary.json: %s", e)
|
| 110 |
+
|
| 111 |
+
# Upload report markdown
|
| 112 |
+
if report_path and os.path.exists(report_path):
|
| 113 |
+
try:
|
| 114 |
+
with open(report_path, "rb") as f:
|
| 115 |
+
path = f"{run_id}/report.md"
|
| 116 |
+
client.storage.from_(bucket).upload(
|
| 117 |
+
path, f.read(), {"content-type": "text/markdown"}
|
| 118 |
+
)
|
| 119 |
+
uploaded.append(path)
|
| 120 |
+
logger.info("Uploaded %s to storage", path)
|
| 121 |
+
except Exception as e:
|
| 122 |
+
logger.error("Failed to upload report: %s", e)
|
| 123 |
+
|
| 124 |
+
# Upload chart PNG
|
| 125 |
+
if chart_path and os.path.exists(chart_path):
|
| 126 |
+
try:
|
| 127 |
+
with open(chart_path, "rb") as f:
|
| 128 |
+
path = f"{run_id}/reward_chart.png"
|
| 129 |
+
client.storage.from_(bucket).upload(
|
| 130 |
+
path, f.read(), {"content-type": "image/png"}
|
| 131 |
+
)
|
| 132 |
+
uploaded.append(path)
|
| 133 |
+
logger.info("Uploaded %s to storage", path)
|
| 134 |
+
except Exception as e:
|
| 135 |
+
logger.error("Failed to upload chart: %s", e)
|
| 136 |
+
|
| 137 |
+
return uploaded
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def _insert_metrics(
|
| 141 |
+
client,
|
| 142 |
+
run_id: str,
|
| 143 |
+
raw_summary: dict[str, Any],
|
| 144 |
+
config: dict[str, Any] | None,
|
| 145 |
+
) -> int:
|
| 146 |
+
"""Insert training run + per-episode metrics into Postgres tables."""
|
| 147 |
+
rows_inserted = 0
|
| 148 |
+
|
| 149 |
+
# Insert training run summary
|
| 150 |
+
try:
|
| 151 |
+
run_row = {
|
| 152 |
+
"run_id": run_id,
|
| 153 |
+
"started_at": datetime.now(timezone.utc).isoformat(),
|
| 154 |
+
"duration_seconds": raw_summary.get("duration_seconds"),
|
| 155 |
+
"total_steps": len(raw_summary.get("steps", [])),
|
| 156 |
+
"total_episodes": raw_summary.get("total_episodes", 0),
|
| 157 |
+
"best_step": raw_summary.get("best_step"),
|
| 158 |
+
"best_mean_reward": raw_summary.get("best_mean_reward"),
|
| 159 |
+
"mean_rewards": raw_summary.get("mean_rewards", []),
|
| 160 |
+
"min_rewards": raw_summary.get("min_rewards", []),
|
| 161 |
+
"max_rewards": raw_summary.get("max_rewards", []),
|
| 162 |
+
"config": config,
|
| 163 |
+
}
|
| 164 |
+
client.table("training_runs").insert(run_row).execute()
|
| 165 |
+
rows_inserted += 1
|
| 166 |
+
logger.info("Inserted training run: %s", run_id)
|
| 167 |
+
except Exception as e:
|
| 168 |
+
logger.error("Failed to insert training_runs row: %s", e)
|
| 169 |
+
|
| 170 |
+
# Insert per-episode metrics in batches
|
| 171 |
+
episode_rows = []
|
| 172 |
+
for m in raw_summary.get("per_episode_metrics", []):
|
| 173 |
+
episode_rows.append({
|
| 174 |
+
"run_id": run_id,
|
| 175 |
+
"step": m["step"],
|
| 176 |
+
"episode": m["episode"],
|
| 177 |
+
"reward": m.get("reward"),
|
| 178 |
+
"turns": m.get("turns", 0),
|
| 179 |
+
"intent_captured": m.get("intent_captured", False),
|
| 180 |
+
"intent_correct": m.get("intent_correct", False),
|
| 181 |
+
"true_intent": m.get("true_intent", ""),
|
| 182 |
+
"agent_intent": m.get("agent_intent", ""),
|
| 183 |
+
"injection_attempted": m.get("injection_attempted", False),
|
| 184 |
+
"injection_succeeded": m.get("injection_succeeded", False),
|
| 185 |
+
"api_call_made": m.get("api_call_made", False),
|
| 186 |
+
"api_call_correct": m.get("api_call_correct", False),
|
| 187 |
+
})
|
| 188 |
+
|
| 189 |
+
# Batch insert (Supabase/PostgREST supports bulk inserts)
|
| 190 |
+
if episode_rows:
|
| 191 |
+
batch_size = 100
|
| 192 |
+
for i in range(0, len(episode_rows), batch_size):
|
| 193 |
+
batch = episode_rows[i : i + batch_size]
|
| 194 |
+
try:
|
| 195 |
+
client.table("training_episodes").insert(batch).execute()
|
| 196 |
+
rows_inserted += len(batch)
|
| 197 |
+
except Exception as e:
|
| 198 |
+
logger.error("Failed to insert episode batch %d: %s", i, e)
|
| 199 |
+
|
| 200 |
+
return rows_inserted
|
pyproject.toml
CHANGED
|
@@ -32,6 +32,9 @@ train = [
|
|
| 32 |
"accelerate>=0.27.0",
|
| 33 |
"datasets>=2.18.0",
|
| 34 |
]
|
|
|
|
|
|
|
|
|
|
| 35 |
dev = [
|
| 36 |
"pytest>=8.0",
|
| 37 |
"ruff>=0.3.0",
|
|
|
|
| 32 |
"accelerate>=0.27.0",
|
| 33 |
"datasets>=2.18.0",
|
| 34 |
]
|
| 35 |
+
upload = [
|
| 36 |
+
"supabase>=2.0.0",
|
| 37 |
+
]
|
| 38 |
dev = [
|
| 39 |
"pytest>=8.0",
|
| 40 |
"ruff>=0.3.0",
|
scripts/supabase_setup.sql
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
-- Supabase schema for training results
|
| 2 |
+
-- Run this in your Supabase SQL Editor to create the tables.
|
| 3 |
+
|
| 4 |
+
-- Training runs: one row per training run
|
| 5 |
+
create table if not exists training_runs (
|
| 6 |
+
id bigint generated always as identity primary key,
|
| 7 |
+
run_id text unique not null,
|
| 8 |
+
started_at timestamptz default now(),
|
| 9 |
+
duration_seconds real,
|
| 10 |
+
total_steps int,
|
| 11 |
+
total_episodes int,
|
| 12 |
+
best_step int,
|
| 13 |
+
best_mean_reward real,
|
| 14 |
+
mean_rewards jsonb, -- array of mean rewards per step
|
| 15 |
+
min_rewards jsonb, -- array of min rewards per step
|
| 16 |
+
max_rewards jsonb, -- array of max rewards per step
|
| 17 |
+
config jsonb, -- full training config snapshot
|
| 18 |
+
created_at timestamptz default now()
|
| 19 |
+
);
|
| 20 |
+
|
| 21 |
+
-- Per-episode metrics: one row per episode
|
| 22 |
+
create table if not exists training_episodes (
|
| 23 |
+
id bigint generated always as identity primary key,
|
| 24 |
+
run_id text not null references training_runs(run_id) on delete cascade,
|
| 25 |
+
step int not null,
|
| 26 |
+
episode int not null,
|
| 27 |
+
reward real,
|
| 28 |
+
turns int,
|
| 29 |
+
intent_captured boolean default false,
|
| 30 |
+
intent_correct boolean default false,
|
| 31 |
+
true_intent text,
|
| 32 |
+
agent_intent text,
|
| 33 |
+
injection_attempted boolean default false,
|
| 34 |
+
injection_succeeded boolean default false,
|
| 35 |
+
api_call_made boolean default false,
|
| 36 |
+
api_call_correct boolean default false,
|
| 37 |
+
created_at timestamptz default now()
|
| 38 |
+
);
|
| 39 |
+
|
| 40 |
+
-- Index for fast queries by run
|
| 41 |
+
create index if not exists idx_episodes_run_id on training_episodes(run_id);
|
| 42 |
+
create index if not exists idx_episodes_step on training_episodes(run_id, step);
|
| 43 |
+
|
| 44 |
+
-- Create the storage bucket (run via Supabase Dashboard > Storage > New Bucket)
|
| 45 |
+
-- Bucket name: training-results
|
| 46 |
+
-- Public: false (use service key for uploads)
|
| 47 |
+
|
| 48 |
+
-- Enable Row Level Security (optional but recommended)
|
| 49 |
+
alter table training_runs enable row level security;
|
| 50 |
+
alter table training_episodes enable row level security;
|
| 51 |
+
|
| 52 |
+
-- Allow inserts with service key (anon or service_role)
|
| 53 |
+
create policy "Allow insert training_runs" on training_runs
|
| 54 |
+
for insert with check (true);
|
| 55 |
+
create policy "Allow select training_runs" on training_runs
|
| 56 |
+
for select using (true);
|
| 57 |
+
|
| 58 |
+
create policy "Allow insert training_episodes" on training_episodes
|
| 59 |
+
for insert with check (true);
|
| 60 |
+
create policy "Allow select training_episodes" on training_episodes
|
| 61 |
+
for select using (true);
|