File size: 8,406 Bytes
28bcb40
76f180f
28bcb40
76f180f
 
 
 
 
28bcb40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76f180f
28bcb40
76f180f
 
 
 
 
28bcb40
 
76f180f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c2dc160
76f180f
 
 
c2dc160
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76f180f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28bcb40
 
76f180f
 
 
 
 
 
 
 
28bcb40
76f180f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28bcb40
76f180f
 
 
 
 
 
 
28bcb40
76f180f
 
 
 
 
28bcb40
76f180f
 
 
 
 
28bcb40
76f180f
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
"""
Supabase uploader for training results — incremental mode.

Uploads after every training step so data is never lost if the job crashes.

- Creates a training_runs row at the start of training
- Upserts that row after each step with updated reward arrays
- Inserts per-episode rows after each step

Requires SUPABASE_URL and SUPABASE_KEY environment variables.
"""

from __future__ import annotations

import json
import logging
import os
from datetime import datetime, timezone
from typing import Any

logger = logging.getLogger(__name__)


def _get_client():
    """Create a Supabase client from environment variables.

    Reads SUPABASE_URL and SUPABASE_KEY from the environment.

    Returns:
        A client from ``supabase.create_client``. Returns None (after
        logging an error) in three cases:
        - the ``supabase`` package cannot be imported;
        - either variable is unset or empty;
        - ``create_client`` itself raises (e.g. a malformed URL or key).
        Callers treat None as "skip uploads".
    """
    try:
        from supabase import create_client
    except ImportError:
        logger.error(
            "supabase package not installed. Install with: pip install 'nested-rl-envs[upload]'"
        )
        return None

    url = os.environ.get("SUPABASE_URL")
    key = os.environ.get("SUPABASE_KEY")
    if not url or not key:
        logger.error("SUPABASE_URL and SUPABASE_KEY must be set")
        return None

    # Misconfigured credentials should degrade to "uploads skipped", the
    # same as missing ones, rather than crash the training job at startup.
    try:
        return create_client(url, key)
    except Exception as e:
        logger.error("Failed to create Supabase client: %s", e)
        return None


class SupabaseUploader:
    """
    Incremental uploader — call after_step() after each training step.

    On construction (when a client is available) a training_runs row with
    zeroed totals is upserted to verify connectivity. Every after_step()
    call re-upserts that row with the updated reward arrays and inserts the
    step's episode rows, which are never re-sent. finish() records the run
    duration and uploads files to Supabase Storage.

    Every remote operation is best-effort: failures are logged, not raised.
    """

    # training_episodes columns copied from each episode log dict, paired
    # with the value used when the log lacks that key.
    _EPISODE_LOG_FIELDS: tuple[tuple[str, Any], ...] = (
        ("turns", 0),
        ("intent_captured", False),
        ("intent_correct", False),
        ("true_intent", ""),
        ("agent_intent", ""),
        ("injection_attempted", False),
        ("injection_succeeded", False),
        ("api_call_made", False),
        ("api_call_correct", False),
    )

    def __init__(
        self,
        run_id: str,
        bucket: str = "training-results",
        config: dict[str, Any] | None = None,
    ):
        """
        Args:
            run_id: Value written to the ``run_id`` column of every row and
                used as the path prefix for Storage uploads.
            bucket: Supabase Storage bucket that finish() uploads files to.
            config: Training config stored in the run row's ``config`` column.
        """
        self.run_id = run_id
        self.bucket = bucket
        self.config = config
        self._client = _get_client()
        # Set once any training_runs upsert succeeds; finish() only issues
        # the duration UPDATE when a row is known to exist.
        self._run_created = False

        # Accumulated arrays (mirrors what training_runs stores); one entry
        # per after_step() call.
        self._mean_rewards: list[float] = []
        self._min_rewards: list[float] = []
        self._max_rewards: list[float] = []
        self._total_episodes = 0
        self._started_at = datetime.now(timezone.utc).isoformat()

        if self._client:
            logger.info("SupabaseUploader ready: run_id=%s", run_id)
            self._write_init_row()
        else:
            logger.warning("SupabaseUploader: no client — uploads will be skipped")

    @staticmethod
    def _native(value: Any) -> Any:
        """Unwrap scalar objects exposing ``item()`` (e.g. ``numpy.float32``).

        The JSON encoder used for API requests only accepts plain Python
        types, so NumPy scalars are converted via ``item()``. Values without
        a callable ``item`` attribute, or whose ``item()`` fails (e.g.
        multi-element arrays), are returned unchanged.
        """
        item = getattr(value, "item", None)
        if not callable(item):
            return value
        try:
            return item()
        except (TypeError, ValueError):
            return value

    def _build_run_row(self) -> dict[str, Any]:
        """Build the training_runs payload from the accumulated arrays.

        ``best_step`` is the 0-based position of the first maximum in
        ``mean_rewards`` (not the caller's ``step`` value). With no recorded
        calls yet, ``best_step`` is 0 and ``best_mean_reward`` is 0.0.
        ``duration_seconds`` is always None here; finish() fills it in.
        """
        if self._mean_rewards:
            best_mean = max(self._mean_rewards)
            best_idx = self._mean_rewards.index(best_mean)
        else:
            best_mean, best_idx = 0.0, 0
        return {
            "run_id": self.run_id,
            "started_at": self._started_at,
            "duration_seconds": None,  # updated at end
            "total_steps": len(self._mean_rewards),
            "total_episodes": self._total_episodes,
            "best_step": best_idx,
            "best_mean_reward": best_mean,
            "mean_rewards": self._mean_rewards,
            "min_rewards": self._min_rewards,
            "max_rewards": self._max_rewards,
            "config": self.config,
        }

    def _write_init_row(self):
        """Write an init row to verify DB connectivity at startup.

        Called from __init__ before any step is recorded, so the row has
        zero totals and empty reward arrays.
        """
        try:
            self._client.table("training_runs").upsert(
                self._build_run_row(), on_conflict="run_id"
            ).execute()
            self._run_created = True
            logger.info("DB init row written successfully (run_id=%s)", self.run_id)
        except Exception as e:
            logger.error("DB init row FAILED — check connection: %s", e)

    @property
    def enabled(self) -> bool:
        """True when a Supabase client was created and uploads will run."""
        return self._client is not None

    def _build_episode_rows(
        self, step: int, eval_result: dict[str, Any]
    ) -> list[dict[str, Any]]:
        """Build one training_episodes row per entry in ``eval_result["logs"]``.

        ``rewards[i]`` is paired with ``logs[i]``; a log without a matching
        reward gets ``reward=None``. Log fields missing from an entry fall
        back to the defaults in ``_EPISODE_LOG_FIELDS``.
        """
        native = self._native
        rewards = eval_result.get("rewards", [])
        rows: list[dict[str, Any]] = []
        for ei, log in enumerate(eval_result.get("logs", [])):
            row: dict[str, Any] = {
                "run_id": self.run_id,
                "step": step,
                "episode": ei,
                "reward": native(rewards[ei]) if ei < len(rewards) else None,
            }
            row.update(
                (column, native(log.get(column, default)))
                for column, default in self._EPISODE_LOG_FIELDS
            )
            rows.append(row)
        return rows

    def after_step(self, step: int, eval_result: dict[str, Any], prompt: str):
        """
        Called after each training step/candidate evaluation.

        Appends this call's rewards to the accumulated arrays, upserts the
        training_runs row, and inserts the new episode rows. Failures of
        either request are logged, not raised.

        Args:
            step: Written to each episode row's ``step`` column and logged.
            eval_result: Reads ``mean_reward``/``min_reward``/``max_reward``
                (default 0.0), ``num_episodes`` (default 0), and the
                ``rewards`` and ``logs`` lists (default empty).
            prompt: Not used by this method.
        """
        if not self.enabled:
            return

        native = self._native
        mean_reward = native(eval_result.get("mean_reward", 0.0))
        self._mean_rewards.append(mean_reward)
        self._min_rewards.append(native(eval_result.get("min_reward", 0.0)))
        self._max_rewards.append(native(eval_result.get("max_reward", 0.0)))
        self._total_episodes += native(eval_result.get("num_episodes", 0))

        # --- Upsert training_runs row ---
        try:
            self._client.table("training_runs").upsert(
                self._build_run_row(), on_conflict="run_id"
            ).execute()
            self._run_created = True
            logger.info(
                "Upserted training_runs: step=%d mean_reward=%.1f",
                step, mean_reward,
            )
        except Exception as e:
            logger.error("Failed to upsert training_runs: %s", e)

        # --- Insert episode rows for this step ---
        episode_rows = self._build_episode_rows(step, eval_result)
        if episode_rows:
            try:
                self._client.table("training_episodes").insert(episode_rows).execute()
                logger.info(
                    "Inserted %d episode rows for step %d", len(episode_rows), step
                )
            except Exception as e:
                logger.error("Failed to insert episodes for step %d: %s", step, e)

    def finish(
        self,
        duration_seconds: float | None = None,
        report_path: str | None = None,
        chart_path: str | None = None,
        raw_summary: dict[str, Any] | None = None,
    ):
        """
        Called at end of training. Updates duration and uploads final files.

        Args:
            duration_seconds: Written to the run row's ``duration_seconds``
                column via UPDATE; skipped when None.
            report_path: Local file uploaded as ``<run_id>/report.md``.
            chart_path: Local file uploaded as ``<run_id>/reward_chart.png``.
            raw_summary: JSON-encoded (non-JSON values via ``str``) and
                uploaded as ``<run_id>/raw_summary.json``; skipped when empty.

        Unset or missing local files are skipped; read and upload errors are
        logged rather than raised.
        """
        if not self.enabled:
            return

        # Update duration on the run row
        if duration_seconds is not None:
            if self._run_created:
                try:
                    self._client.table("training_runs").update(
                        {"duration_seconds": duration_seconds}
                    ).eq("run_id", self.run_id).execute()
                    logger.info("Updated duration: %.1fs", duration_seconds)
                except Exception as e:
                    logger.error("Failed to update duration: %s", e)
            else:
                # Report the skip instead of dropping the duration silently.
                logger.warning(
                    "Skipping duration update: training_runs row was never created (run_id=%s)",
                    self.run_id,
                )

        # Upload files to Storage
        if raw_summary:
            self._upload_file(
                f"{self.run_id}/raw_summary.json",
                json.dumps(raw_summary, indent=2, default=str).encode(),
                "application/json",
            )

        self._upload_local_file(report_path, "report.md", "text/markdown")
        self._upload_local_file(chart_path, "reward_chart.png", "image/png")

    def _upload_local_file(
        self, local_path: str | None, name: str, content_type: str
    ):
        """Upload a local file to Storage as ``<run_id>/<name>``.

        Does nothing when ``local_path`` is unset or does not exist; a read
        error is logged so finish() stays best-effort.
        """
        if not local_path or not os.path.exists(local_path):
            return
        try:
            with open(local_path, "rb") as fh:
                data = fh.read()
        except OSError as e:
            logger.error("Failed to read %s: %s", local_path, e)
            return
        self._upload_file(f"{self.run_id}/{name}", data, content_type)

    def _upload_file(self, path: str, data: bytes, content_type: str):
        """Upload a single file to Supabase Storage (errors are logged)."""
        try:
            self._client.storage.from_(self.bucket).upload(
                path, data, {"content-type": content_type}
            )
            logger.info("Uploaded %s to storage", path)
        except Exception as e:
            logger.error("Failed to upload %s: %s", path, e)