| import time | |
| import pickle | |
| import psycopg2 | |
| import psycopg2.extras | |
| from config import config | |
class StorageManager:
    """
    Neon (Postgres) checkpoint store with on-demand connections + simple retries.

    Every public call opens its own connection through ``_connect`` and closes
    it before returning; no connection is kept between calls. Model and
    optimizer states are stored as pickled ``BYTEA`` blobs in the
    ``tarang_checkpoints`` table, alongside an optional JSONB ``meta`` column.
    """

    # Number of connection attempts made by ``_connect`` before giving up.
    CONNECT_ATTEMPTS = 5

    def __init__(self):
        """
        Validate configuration; no database connection is opened here.

        Raises:
            RuntimeError: if ``config.NEON_CONNECTION`` is empty/falsy.
        """
        if not config.NEON_CONNECTION:
            raise RuntimeError("Missing NEON_CONNECTION_STRING (or DATABASE_URL) in Space Secrets.")
        # Table/index creation is deferred until the first save/load call.
        self._schema_ready = False

    def _connect(self):
        """
        Open a new psycopg2 connection, retrying with exponential backoff.

        Between failed attempts it waits 1, 2, 4, ... seconds. It does not
        wait after the final attempt, so the failure is reported immediately.

        Returns:
            A new ``psycopg2`` connection; the caller is responsible for closing it.

        Raises:
            RuntimeError: if every attempt fails. The last underlying error is
                chained as ``__cause__``.
        """
        last_err = None
        for attempt in range(self.CONNECT_ATTEMPTS):
            try:
                return psycopg2.connect(config.NEON_CONNECTION, connect_timeout=15)
            except Exception as e:  # deliberately broad: any connect failure is retried
                last_err = e
                if attempt < self.CONNECT_ATTEMPTS - 1:
                    time.sleep(2 ** attempt)
        raise RuntimeError(f"Neon connect failed after retries: {last_err}") from last_err

    def _init_schema_once(self):
        """
        Create the checkpoint table and its time index if they do not exist.

        The flag is set only after a successful commit. A failed attempt is
        therefore retried on the next call.
        """
        if self._schema_ready:
            return
        conn = self._connect()
        try:
            with conn.cursor() as cur:
                cur.execute("""
                    CREATE TABLE IF NOT EXISTS tarang_checkpoints (
                        id SERIAL PRIMARY KEY,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        model_state BYTEA NOT NULL,
                        optim_state BYTEA NOT NULL,
                        meta JSONB
                    );
                """)
                cur.execute("CREATE INDEX IF NOT EXISTS idx_tarang_ckpt_time ON tarang_checkpoints(created_at DESC);")
            conn.commit()
            self._schema_ready = True
        finally:
            # Closing without commit discards any partial transaction.
            conn.close()

    def save_checkpoint(self, model_state: dict, optim_state: dict, meta: dict, keep_last: int = 10):
        """
        Insert a checkpoint, then prune so that only the newest ``keep_last`` rows remain.

        Args:
            model_state: Object to pickle into ``model_state``.
            optim_state: Object to pickle into ``optim_state``.
            meta: JSON-serializable metadata stored in the ``meta`` JSONB column.
            keep_last: Number of most recent checkpoints to retain (default 10).

        Raises:
            ValueError: if ``keep_last`` < 1. A value of 0 would delete the
                checkpoint that was just written.
            RuntimeError: if a connection cannot be established.
        """
        if keep_last < 1:
            raise ValueError(f"keep_last must be >= 1, got {keep_last}")
        self._init_schema_once()
        # Serialize before connecting so a pickling error never opens a connection.
        ms = pickle.dumps(model_state)
        os_ = pickle.dumps(optim_state)
        conn = self._connect()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    "INSERT INTO tarang_checkpoints(model_state, optim_state, meta) VALUES (%s, %s, %s)",
                    (psycopg2.Binary(ms), psycopg2.Binary(os_), psycopg2.extras.Json(meta)),
                )
                # ``id DESC`` breaks created_at ties deterministically. It must match
                # the ordering in load_latest so the "latest" row is never pruned.
                cur.execute(
                    """
                    DELETE FROM tarang_checkpoints
                    WHERE id NOT IN (
                        SELECT id FROM tarang_checkpoints
                        ORDER BY created_at DESC, id DESC
                        LIMIT %s
                    );
                    """,
                    (keep_last,),
                )
            # Insert and prune commit together as one transaction.
            conn.commit()
        finally:
            conn.close()

    def load_latest(self):
        """
        Return the most recent checkpoint, or ``None`` if the table is empty.

        Returns:
            ``None`` or a dict with keys ``created_at`` (ISO string or ``None``),
            ``model_state`` and ``optim_state`` (unpickled objects), and ``meta``
            (the decoded JSONB value).

        Security:
            ``pickle.loads`` executes arbitrary code embedded in the blob. Only
            use this against a database whose writers you fully trust.
        """
        self._init_schema_once()
        conn = self._connect()
        try:
            with conn.cursor() as cur:
                # Same ordering as the prune query in save_checkpoint.
                cur.execute("""
                    SELECT created_at, model_state, optim_state, meta
                    FROM tarang_checkpoints
                    ORDER BY created_at DESC, id DESC
                    LIMIT 1;
                """)
                row = cur.fetchone()
            if not row:
                return None
            created_at, model_b, optim_b, meta = row
            return {
                "created_at": created_at.isoformat() if created_at else None,
                "model_state": pickle.loads(model_b),  # trusted-source data only (see docstring)
                "optim_state": pickle.loads(optim_b),
                "meta": meta,
            }
        finally:
            conn.close()