# NMFL / gpu_state_db.py
# Factor Studios — commit 520d6cf ("Upload 43 files", verified)
import sqlite3
import json
import threading
class GPUStateDB:
    """SQLite store for JSON-serialized GPU simulator state.

    The schema has one table per hardware unit (``sm``, ``core``, ``warp``,
    ``thread``, ``tensor_core``). Each table has an INTEGER PRIMARY KEY id
    column and a ``state_json`` TEXT column, which is what ``save_state`` and
    ``load_state`` read and write.

    A single connection is opened with ``check_same_thread=False`` so it can be
    shared across threads. Every method of this class serializes its use of
    that connection through ``self.lock``. Code that touches ``self.conn``
    directly bypasses that lock.

    Can be used as a context manager; the connection is closed on exit.
    """

    # DDL for the per-unit tables, executed by _init_tables.
    _SCHEMA = (
        '''CREATE TABLE IF NOT EXISTS sm (
            sm_id INTEGER PRIMARY KEY,
            chip_id INTEGER,
            state_json TEXT
        )''',
        '''CREATE TABLE IF NOT EXISTS core (
            core_id INTEGER PRIMARY KEY,
            sm_id INTEGER,
            registers BLOB,
            state_json TEXT
        )''',
        '''CREATE TABLE IF NOT EXISTS warp (
            warp_id INTEGER PRIMARY KEY,
            sm_id INTEGER,
            thread_ids TEXT,
            state_json TEXT
        )''',
        '''CREATE TABLE IF NOT EXISTS thread (
            thread_id INTEGER PRIMARY KEY,
            warp_id INTEGER,
            core_id INTEGER,
            state_json TEXT
        )''',
        '''CREATE TABLE IF NOT EXISTS tensor_core (
            tensor_core_id INTEGER PRIMARY KEY,
            sm_id INTEGER,
            memory BLOB,
            state_json TEXT
        )''',
    )

    def __init__(self, db_path='gpu_state.db'):
        """Open the database at *db_path* and create any missing tables.

        Args:
            db_path: Path passed to ``sqlite3.connect`` (``":memory:"`` gives a
                private in-memory database).
        """
        # Shared between threads; self.lock serializes every use inside this class.
        self.conn = sqlite3.connect(db_path, check_same_thread=False)
        self.lock = threading.Lock()
        self._init_tables()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False  # never suppress exceptions from the with-body

    def _init_tables(self):
        """Create the per-unit tables if they do not already exist."""
        # `with self.conn` commits on success / rolls back on error.
        with self.lock, self.conn:
            for ddl in self._SCHEMA:
                self.conn.execute(ddl)

    def _require_open(self):
        """Return the live connection, or raise if ``close()`` was called.

        Must be called with ``self.lock`` held.

        Raises:
            sqlite3.ProgrammingError: if the database has been closed.
        """
        if self.conn is None:
            # Same exception type/message sqlite3 uses for a closed connection.
            raise sqlite3.ProgrammingError("Cannot operate on a closed database.")
        return self.conn

    @staticmethod
    def _check_identifier(name):
        """Return *name* unchanged if it is safe to splice into SQL as an identifier.

        Table and column names cannot be bound as ``?`` parameters, so they are
        interpolated into the statement text. Restricting them to plain
        identifiers (``str.isidentifier``: no spaces, quotes, semicolons, ...)
        prevents SQL injection through ``table``/``id_name``.

        Raises:
            ValueError: if *name* is not a ``str`` identifier.
        """
        if not (isinstance(name, str) and name.isidentifier()):
            raise ValueError(f"invalid SQL identifier: {name!r}")
        return name

    def save_state(self, table, id_name, id_value, state):
        """Store *state* as JSON in the row of *table* whose *id_name* equals *id_value*.

        Only the ``state_json`` column of matching rows is updated, so any
        other columns already stored in the row are preserved. If no row
        matches, a new row is inserted containing just the id and the state.

        Args:
            table: Table name (must be a plain identifier).
            id_name: Key column name (must be a plain identifier).
            id_value: Key value; bound as a SQL parameter.
            state: JSON-serializable object.

        Raises:
            TypeError: if *state* is not JSON-serializable.
            ValueError: if *table* or *id_name* is not a plain identifier.
            sqlite3.ProgrammingError: if the database has been closed.
            sqlite3.OperationalError: if the table or column does not exist.
        """
        table = self._check_identifier(table)
        id_name = self._check_identifier(id_name)
        state_json = json.dumps(state)
        with self.lock:
            conn = self._require_open()
            # Commit on success, roll back on error: no half-open transaction
            # is left behind if a statement fails.
            with conn:
                # UPDATE-then-INSERT rather than INSERT OR REPLACE: REPLACE
                # deletes the existing row first, which would reset every
                # column other than the two written here to NULL.
                cur = conn.execute(
                    f"UPDATE {table} SET state_json=? WHERE {id_name}=?",
                    (state_json, id_value),
                )
                if cur.rowcount == 0:
                    conn.execute(
                        f"INSERT INTO {table} ({id_name}, state_json) VALUES (?, ?)",
                        (id_value, state_json),
                    )

    def load_state(self, table, id_name, id_value):
        """Return the decoded ``state_json`` of the row where *id_name* equals *id_value*.

        Args:
            table: Table name (must be a plain identifier).
            id_name: Key column name (must be a plain identifier).
            id_value: Key value; bound as a SQL parameter.

        Returns:
            The JSON-decoded state, or ``None`` if no row matches or the
            matching row has no stored state (NULL ``state_json``).

        Raises:
            ValueError: if *table* or *id_name* is not a plain identifier.
            sqlite3.ProgrammingError: if the database has been closed.
            sqlite3.OperationalError: if the table or column does not exist.
        """
        table = self._check_identifier(table)
        id_name = self._check_identifier(id_name)
        with self.lock:
            conn = self._require_open()
            row = conn.execute(
                f"SELECT state_json FROM {table} WHERE {id_name}=?", (id_value,)
            ).fetchone()
        # A row may exist with NULL state_json (e.g. inserted with only other
        # columns); json.loads(None) would raise TypeError.
        if row is None or row[0] is None:
            return None
        return json.loads(row[0])

    def close(self):
        """Close the connection. Safe to call more than once.

        After closing, ``save_state``/``load_state`` raise
        ``sqlite3.ProgrammingError``.
        """
        # Hold the lock so the connection is not closed under a concurrent query.
        with self.lock:
            if self.conn is not None:
                self.conn.close()
                self.conn = None