import time
import pickle

import psycopg2
import psycopg2.extras

from config import config


class StorageManager:
    """
    Neon checkpoint store with on-demand connections + simple retries.

    Every public call opens a fresh psycopg2 connection (with retry and
    exponential backoff), does its work, and closes the connection. After each
    save only the ``keep_last`` most recent checkpoints are kept.
    """

    _CONNECT_ATTEMPTS = 5  # total connection attempts before giving up
    _CONNECT_TIMEOUT_S = 15  # per-attempt timeout passed to psycopg2.connect

    def __init__(self, keep_last: int = 10):
        """Validate configuration; the table is created lazily on first use.

        Args:
            keep_last: Number of newest checkpoints retained by
                ``save_checkpoint``. The default of 10 matches the previous
                hard-coded value.

        Raises:
            RuntimeError: If ``config.NEON_CONNECTION`` is empty or unset.
            ValueError: If ``keep_last`` is less than 1. A value of 0 would
                make the prune delete the checkpoint that was just inserted.
        """
        if not config.NEON_CONNECTION:
            raise RuntimeError("Missing NEON_CONNECTION_STRING (or DATABASE_URL) in Space Secrets.")
        if keep_last < 1:
            raise ValueError(f"keep_last must be >= 1, got {keep_last!r}")
        self.keep_last = keep_last
        self._schema_ready = False

    def _connect(self):
        """Open a new connection, retrying with exponential backoff.

        Sleeps 1, 2, 4, 8 seconds between failed attempts. There is no sleep
        after the final failure, so the error surfaces immediately.

        Raises:
            RuntimeError: After all attempts fail. The last underlying error
                is chained as ``__cause__``.
        """
        last = None
        for attempt in range(self._CONNECT_ATTEMPTS):
            try:
                return psycopg2.connect(config.NEON_CONNECTION, connect_timeout=self._CONNECT_TIMEOUT_S)
            except Exception as e:  # broad on purpose: callers only see RuntimeError
                last = e
                if attempt < self._CONNECT_ATTEMPTS - 1:
                    time.sleep(2 ** attempt)
        raise RuntimeError(f"Neon connect failed after retries: {last}") from last

    def _init_schema_once(self):
        """Create the checkpoint table and its time index if missing.

        Runs at most once per instance. ``_schema_ready`` is set only after a
        successful commit, so a failed attempt is retried on the next call.
        """
        if self._schema_ready:
            return
        conn = self._connect()
        try:
            with conn.cursor() as cur:
                cur.execute("""
                    CREATE TABLE IF NOT EXISTS tarang_checkpoints (
                        id SERIAL PRIMARY KEY,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        model_state BYTEA NOT NULL,
                        optim_state BYTEA NOT NULL,
                        meta JSONB
                    );
                """)
                cur.execute("CREATE INDEX IF NOT EXISTS idx_tarang_ckpt_time ON tarang_checkpoints(created_at DESC);")
            conn.commit()
            self._schema_ready = True
        finally:
            conn.close()

    def save_checkpoint(self, model_state: dict, optim_state: dict, meta: dict):
        """Store a checkpoint and prune all but the newest ``keep_last`` rows.

        Args:
            model_state: Object pickled into the ``model_state`` BYTEA column.
            optim_state: Object pickled into the ``optim_state`` BYTEA column.
            meta: JSON-serializable value stored in the ``meta`` JSONB column.

        The insert and the prune share one transaction, committed together.
        If either statement fails, nothing is committed before the
        connection is closed.
        """
        self._init_schema_once()
        # Serialize before connecting so a pickling error never holds a connection.
        ms = pickle.dumps(model_state)
        os_ = pickle.dumps(optim_state)
        conn = self._connect()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    "INSERT INTO tarang_checkpoints(model_state, optim_state, meta) VALUES (%s, %s, %s)",
                    (psycopg2.Binary(ms), psycopg2.Binary(os_), psycopg2.extras.Json(meta)),
                )
                # CURRENT_TIMESTAMP is the transaction start time, so created_at
                # values can tie. Breaking ties on id keeps pruning deterministic
                # and consistent with the row load_latest treats as newest.
                cur.execute(
                    """
                    DELETE FROM tarang_checkpoints
                    WHERE id NOT IN (
                        SELECT id FROM tarang_checkpoints
                        ORDER BY created_at DESC, id DESC
                        LIMIT %s
                    );
                    """,
                    (self.keep_last,),
                )
            conn.commit()
        finally:
            conn.close()

    def load_latest(self):
        """Return the newest checkpoint, or ``None`` if the table is empty.

        Returns:
            A dict with keys ``created_at`` (ISO-8601 string or ``None``),
            ``model_state`` and ``optim_state`` (unpickled objects), and
            ``meta`` (the decoded JSONB value).
        """
        self._init_schema_once()
        conn = self._connect()
        try:
            with conn.cursor() as cur:
                cur.execute("""
                    SELECT created_at, model_state, optim_state, meta
                    FROM tarang_checkpoints
                    ORDER BY created_at DESC, id DESC
                    LIMIT 1;
                """)
                row = cur.fetchone()
            if not row:
                return None
            created_at, model_b, optim_b, meta = row
            # SECURITY NOTE(review): pickle.loads executes arbitrary code if the
            # stored bytes are attacker-controlled. This is only safe while the
            # database (and its credentials) are fully trusted.
            return {
                "created_at": created_at.isoformat() if created_at else None,
                "model_state": pickle.loads(model_b),
                "optim_state": pickle.loads(optim_b),
                "meta": meta,
            }
        finally:
            conn.close()