# Tarang_v2 / storage_manager.py
# unknownfriend00007's picture
# Update storage_manager.py
# 0e0447f verified
import time
import pickle
import psycopg2
import psycopg2.extras
from config import config
class StorageManager:
    """Neon checkpoint store with on-demand connections + simple retries.

    Each operation opens its own connection via ``_connect`` and closes it in
    a ``finally`` block; no connection is kept between calls.
    """

    # Number of connection attempts made by ``_connect`` before giving up.
    MAX_CONNECT_ATTEMPTS = 5
    # Value passed to psycopg2 as ``connect_timeout`` for every attempt.
    CONNECT_TIMEOUT_S = 15
    # Number of newest checkpoints kept by ``save_checkpoint``.
    KEEP_LAST = 10

    def __init__(self):
        """Validate configuration; no database connection is opened here.

        Raises:
            RuntimeError: if ``config.NEON_CONNECTION`` is empty or unset.
        """
        if not config.NEON_CONNECTION:
            raise RuntimeError("Missing NEON_CONNECTION_STRING (or DATABASE_URL) in Space Secrets.")
        # Flipped to True after the table/index DDL has been committed once.
        self._schema_ready = False

    def _connect(self):
        """Open a new psycopg2 connection, retrying with exponential backoff.

        Waits ``2 ** i`` seconds between attempts (1, 2, 4, 8 with the default
        of five attempts). No delay follows the final attempt, so the failure
        is reported as soon as it is known.

        Returns:
            The connection returned by ``psycopg2.connect``.

        Raises:
            RuntimeError: if every attempt fails; the last underlying
                exception is attached as ``__cause__``.
        """
        last = None
        attempts = self.MAX_CONNECT_ATTEMPTS
        for i in range(attempts):
            try:
                return psycopg2.connect(
                    config.NEON_CONNECTION,
                    connect_timeout=self.CONNECT_TIMEOUT_S,
                )
            except Exception as e:
                # Deliberately broad: every failure is retried and finally
                # surfaced as RuntimeError, which is what callers see today.
                last = e
                if i < attempts - 1:
                    # Back off only when another attempt will follow.
                    time.sleep(2 ** i)
        raise RuntimeError(f"Neon connect failed after retries: {last}") from last

    def _init_schema_once(self):
        """Create the checkpoint table and its index if not done yet.

        The DDL uses ``IF NOT EXISTS``; ``_schema_ready`` is set only after a
        successful commit, so a failed attempt is retried on the next call.
        """
        if self._schema_ready:
            return
        conn = self._connect()
        try:
            with conn.cursor() as cur:
                cur.execute("""
                    CREATE TABLE IF NOT EXISTS tarang_checkpoints (
                        id SERIAL PRIMARY KEY,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        model_state BYTEA NOT NULL,
                        optim_state BYTEA NOT NULL,
                        meta JSONB
                    );
                """)
                cur.execute("CREATE INDEX IF NOT EXISTS idx_tarang_ckpt_time ON tarang_checkpoints(created_at DESC);")
            conn.commit()
            self._schema_ready = True
        finally:
            conn.close()

    def save_checkpoint(self, model_state: dict, optim_state: dict, meta: dict):
        """Insert a new checkpoint and prune all but the newest ``KEEP_LAST``.

        Args:
            model_state: object pickled into the ``model_state`` column.
            optim_state: object pickled into the ``optim_state`` column.
            meta: value stored in the ``meta`` JSONB column.

        The insert and the prune share one transaction: ``commit`` is reached
        only when both statements succeed; on error the connection is closed
        without committing.
        """
        self._init_schema_once()
        ms = pickle.dumps(model_state)
        os_ = pickle.dumps(optim_state)
        conn = self._connect()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    "INSERT INTO tarang_checkpoints(model_state, optim_state, meta) VALUES (%s, %s, %s)",
                    (psycopg2.Binary(ms), psycopg2.Binary(os_), psycopg2.extras.Json(meta)),
                )
                # ``id DESC`` breaks ties between rows with equal created_at,
                # so the set of surviving rows is deterministic.
                cur.execute(
                    """
                    DELETE FROM tarang_checkpoints
                    WHERE id NOT IN (
                        SELECT id FROM tarang_checkpoints
                        ORDER BY created_at DESC, id DESC
                        LIMIT %s
                    );
                    """,
                    (self.KEEP_LAST,),
                )
            conn.commit()
        finally:
            conn.close()

    def load_latest(self):
        """Return the most recent checkpoint, or ``None`` if the table is empty.

        Returns:
            A dict with keys ``created_at`` (ISO-8601 string, or ``None`` when
            the column is null), ``model_state`` and ``optim_state`` (unpickled
            objects) and ``meta`` (as returned by the driver), or ``None``.
        """
        self._init_schema_once()
        conn = self._connect()
        try:
            with conn.cursor() as cur:
                # ``id DESC`` makes the choice deterministic when two rows
                # share the same created_at value.
                cur.execute("""
                    SELECT created_at, model_state, optim_state, meta
                    FROM tarang_checkpoints
                    ORDER BY created_at DESC, id DESC
                    LIMIT 1;
                """)
                row = cur.fetchone()
            if not row:
                return None
            created_at, model_b, optim_b, meta = row
            # SECURITY: pickle.loads can execute arbitrary code embedded in the
            # payload. This is only safe while everything that can write to
            # tarang_checkpoints is trusted.
            return {
                "created_at": created_at.isoformat() if created_at else None,
                "model_state": pickle.loads(model_b),
                "optim_state": pickle.loads(optim_b),
                "meta": meta,
            }
        finally:
            conn.close()