# Claude
# Add SFT warm start before GRPO and DB connectivity init check
# c2dc160 unverified
"""
Supabase uploader for training results — incremental mode.
Uploads after every training step so data is never lost if the job crashes.
- Creates a training_runs row at the start of training
- Upserts that row after each step with updated reward arrays
- Inserts per-episode rows after each step
Requires SUPABASE_URL and SUPABASE_KEY environment variables.
"""
from __future__ import annotations
import json
import logging
import os
from datetime import datetime, timezone
from typing import Any
logger = logging.getLogger(__name__)
def _get_client():
    """Create a Supabase client from environment variables.

    Reads ``SUPABASE_URL`` and ``SUPABASE_KEY`` from the environment and
    passes them to ``supabase.create_client``.

    Returns:
        The client object, or ``None`` when the ``supabase`` package cannot be
        imported, when either environment variable is unset or empty, or when
        ``create_client`` itself raises. Every failure is logged and none is
        raised, so callers only need to check for ``None``.
    """
    try:
        # Imported here rather than at module level so that importing this
        # module does not require the supabase package to be installed.
        from supabase import create_client
    except ImportError:
        logger.error(
            "supabase package not installed. Install with: pip install 'nested-rl-envs[upload]'"
        )
        return None

    url = os.environ.get("SUPABASE_URL")
    key = os.environ.get("SUPABASE_KEY")
    if not url or not key:
        logger.error("SUPABASE_URL and SUPABASE_KEY must be set")
        return None

    try:
        return create_client(url, key)
    except Exception as e:
        # A malformed URL/key must not escape: the caller constructs the
        # uploader at training start-up and treats a None client as
        # "uploads disabled".
        logger.error("Failed to create Supabase client: %s", e)
        return None
class SupabaseUploader:
    """
    Incremental uploader — call after_step() after each training step.

    When a client is available, the constructor immediately upserts an
    initial training_runs row, which doubles as a connectivity check.
    Each after_step() call upserts that row again with the accumulated
    reward arrays and inserts the step's episode rows; episode rows are
    inserted immediately and never re-sent. finish() records the run
    duration and uploads the final artifact files to Storage.

    Failed Supabase requests are logged rather than raised. When no client
    could be created, every public method is a no-op.
    """

    def __init__(
        self,
        run_id: str,
        bucket: str = "training-results",
        config: dict[str, Any] | None = None,
    ):
        """
        Args:
            run_id: Identifier written to every row and used as the Storage
                path prefix; training_runs is upserted on this column.
            bucket: Storage bucket that finish() uploads files into.
            config: Optional run configuration stored in the "config" column.
        """
        self.run_id = run_id
        self.bucket = bucket
        self.config = config
        self._client = _get_client()
        # True once any training_runs upsert has succeeded.
        self._run_created = False

        # Accumulated arrays (mirrors what training_runs stores)
        self._mean_rewards: list[float] = []
        self._min_rewards: list[float] = []
        self._max_rewards: list[float] = []
        self._total_episodes = 0
        self._started_at = datetime.now(timezone.utc).isoformat()

        if self._client:
            logger.info("SupabaseUploader ready: run_id=%s", run_id)
            self._write_init_row()
        else:
            logger.warning("SupabaseUploader: no client — uploads will be skipped")

    def _build_run_row(self) -> dict[str, Any]:
        """Return the training_runs payload for the rewards accumulated so far.

        Single source of truth for the row shape, so the init row and the
        per-step rows cannot drift apart. Before any step has been recorded,
        best_step is 0 and best_mean_reward is 0.0.
        """
        if self._mean_rewards:
            best_mean = max(self._mean_rewards)
            # Position of the first occurrence in the accumulated array, i.e.
            # a 0-based count of after_step() calls — not the caller's `step`.
            best_idx = self._mean_rewards.index(best_mean)
        else:
            best_mean, best_idx = 0.0, 0
        return {
            "run_id": self.run_id,
            "started_at": self._started_at,
            "duration_seconds": None,  # updated at end
            "total_steps": len(self._mean_rewards),
            "total_episodes": self._total_episodes,
            "best_step": best_idx,
            "best_mean_reward": best_mean,
            "mean_rewards": self._mean_rewards,
            "min_rewards": self._min_rewards,
            "max_rewards": self._max_rewards,
            "config": self.config,
        }

    def _write_init_row(self):
        """Write an init row to verify DB connectivity at startup."""
        try:
            self._client.table("training_runs").upsert(
                self._build_run_row(), on_conflict="run_id"
            ).execute()
            self._run_created = True
            logger.info("DB init row written successfully (run_id=%s)", self.run_id)
        except Exception as e:
            logger.error("DB init row FAILED — check connection: %s", e)

    @property
    def enabled(self) -> bool:
        """Whether a Supabase client was created, i.e. uploads are attempted."""
        return self._client is not None

    def after_step(self, step: int, eval_result: dict[str, Any], prompt: str):
        """
        Called after each training step/candidate evaluation.
        Upserts the training_runs row and inserts new episode rows.

        Args:
            step: Step number stored on each episode row and used in logs.
            eval_result: Evaluation summary. Reads "mean_reward",
                "min_reward", "max_reward" (default 0.0), "num_episodes"
                (default 0), "rewards" and "logs" (default empty).
            prompt: Currently unused; kept so existing callers keep working.
        """
        if not self._client:
            return

        mean_reward = eval_result.get("mean_reward", 0.0)
        self._mean_rewards.append(mean_reward)
        self._min_rewards.append(eval_result.get("min_reward", 0.0))
        self._max_rewards.append(eval_result.get("max_reward", 0.0))
        self._total_episodes += eval_result.get("num_episodes", 0)

        # --- Upsert training_runs row ---
        try:
            self._client.table("training_runs").upsert(
                self._build_run_row(), on_conflict="run_id"
            ).execute()
            self._run_created = True
            logger.info(
                "Upserted training_runs: step=%d mean_reward=%.1f",
                step, mean_reward,
            )
        except Exception as e:
            logger.error("Failed to upsert training_runs: %s", e)

        # --- Insert episode rows for this step ---
        episode_rows = self._build_episode_rows(step, eval_result)
        if not episode_rows:
            return
        try:
            self._client.table("training_episodes").insert(episode_rows).execute()
            logger.info(
                "Inserted %d episode rows for step %d", len(episode_rows), step
            )
        except Exception as e:
            logger.error("Failed to insert episodes for step %d: %s", step, e)

    def _build_episode_rows(
        self, step: int, eval_result: dict[str, Any]
    ) -> list[dict[str, Any]]:
        """Return one training_episodes payload per entry in eval_result["logs"]."""
        rewards_list = eval_result.get("rewards", [])
        return [
            {
                "run_id": self.run_id,
                "step": step,
                "episode": ei,
                # "rewards" may be shorter than "logs"; store NULL for the gap.
                "reward": rewards_list[ei] if ei < len(rewards_list) else None,
                "turns": log.get("turns", 0),
                "intent_captured": log.get("intent_captured", False),
                "intent_correct": log.get("intent_correct", False),
                "true_intent": log.get("true_intent", ""),
                "agent_intent": log.get("agent_intent", ""),
                "injection_attempted": log.get("injection_attempted", False),
                "injection_succeeded": log.get("injection_succeeded", False),
                "api_call_made": log.get("api_call_made", False),
                "api_call_correct": log.get("api_call_correct", False),
            }
            for ei, log in enumerate(eval_result.get("logs", []))
        ]

    def finish(
        self,
        duration_seconds: float | None = None,
        report_path: str | None = None,
        chart_path: str | None = None,
        raw_summary: dict[str, Any] | None = None,
    ):
        """
        Called at end of training. Updates duration and uploads final files.

        Args:
            duration_seconds: Written to the run row, but only if a
                training_runs upsert succeeded earlier.
            report_path: Local file uploaded as "<run_id>/report.md".
            chart_path: Local file uploaded as "<run_id>/reward_chart.png".
            raw_summary: Serialized to JSON (non-JSON values via str()) and
                uploaded as "<run_id>/raw_summary.json". Skipped when empty.

        Paths that are None, empty or nonexistent are skipped silently;
        paths that exist but cannot be read are logged and skipped.
        """
        if not self._client:
            return

        # Update duration on the run row
        if duration_seconds is not None and self._run_created:
            try:
                self._client.table("training_runs").update(
                    {"duration_seconds": duration_seconds}
                ).eq("run_id", self.run_id).execute()
                logger.info("Updated duration: %.1fs", duration_seconds)
            except Exception as e:
                logger.error("Failed to update duration: %s", e)

        # Upload files to Storage
        if raw_summary:
            self._upload_file(
                f"{self.run_id}/raw_summary.json",
                json.dumps(raw_summary, indent=2, default=str).encode(),
                "application/json",
            )
        self._upload_local_file(
            report_path, f"{self.run_id}/report.md", "text/markdown"
        )
        self._upload_local_file(
            chart_path, f"{self.run_id}/reward_chart.png", "image/png"
        )

    def _upload_local_file(
        self, local_path: str | None, remote_path: str, content_type: str
    ):
        """Read a local file and upload it; log instead of raising on read errors."""
        if not local_path or not os.path.exists(local_path):
            return
        try:
            with open(local_path, "rb") as f:
                data = f.read()
        except OSError as e:
            # e.g. the path is a directory or not readable — one bad artifact
            # must not abort the remaining uploads or the caller.
            logger.error("Failed to read %s: %s", local_path, e)
            return
        self._upload_file(remote_path, data, content_type)

    def _upload_file(self, path: str, data: bytes, content_type: str):
        """Upload a single file to Supabase Storage."""
        try:
            self._client.storage.from_(self.bucket).upload(
                path, data, {"content-type": content_type}
            )
            logger.info("Uploaded %s to storage", path)
        except Exception as e:
            logger.error("Failed to upload %s: %s", path, e)