import sqlite3
import json
import threading


class GPUStateDB:
    """SQLite-backed store for per-component GPU state.

    One table exists per component kind (``sm``, ``core``, ``warp``,
    ``thread``, ``tensor_core``). Each has an integer primary key, a
    ``state_json`` column holding a JSON-encoded state object, and
    component-specific columns (parent ids, ``registers``/``memory`` blobs,
    ``thread_ids``).

    A single connection is shared across threads (``check_same_thread=False``);
    every use of it is serialized through ``self.lock``. Instances can be used
    as context managers, closing the connection on exit.
    """

    def __init__(self, db_path='gpu_state.db'):
        """Open (or create) the database at *db_path* and ensure tables exist."""
        self.conn = sqlite3.connect(db_path, check_same_thread=False)
        self.lock = threading.Lock()
        self._init_tables()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    @staticmethod
    def _quote_ident(name):
        """Return *name* quoted as an SQLite identifier.

        Table/column names cannot be bound as ``?`` parameters, so they are
        interpolated into the SQL text; double-quoting (with embedded quotes
        doubled) keeps a name from being parsed as arbitrary SQL.
        """
        return '"' + str(name).replace('"', '""') + '"'

    def _init_tables(self):
        """Create the component tables if they do not already exist."""
        with self.lock, self.conn:
            c = self.conn.cursor()
            c.execute('''CREATE TABLE IF NOT EXISTS sm (
                sm_id INTEGER PRIMARY KEY,
                chip_id INTEGER,
                state_json TEXT
            )''')
            c.execute('''CREATE TABLE IF NOT EXISTS core (
                core_id INTEGER PRIMARY KEY,
                sm_id INTEGER,
                registers BLOB,
                state_json TEXT
            )''')
            c.execute('''CREATE TABLE IF NOT EXISTS warp (
                warp_id INTEGER PRIMARY KEY,
                sm_id INTEGER,
                thread_ids TEXT,
                state_json TEXT
            )''')
            c.execute('''CREATE TABLE IF NOT EXISTS thread (
                thread_id INTEGER PRIMARY KEY,
                warp_id INTEGER,
                core_id INTEGER,
                state_json TEXT
            )''')
            c.execute('''CREATE TABLE IF NOT EXISTS tensor_core (
                tensor_core_id INTEGER PRIMARY KEY,
                sm_id INTEGER,
                memory BLOB,
                state_json TEXT
            )''')

    def save_state(self, table, id_name, id_value, state):
        """Store *state* (JSON-encoded) as the ``state_json`` of one row.

        Args:
            table: name of the table to write.
            id_name: name of the key column used to locate the row.
            id_value: key value of the row.
            state: any ``json.dumps``-serializable object.

        Only ``state_json`` is written; other columns of an existing row
        are left untouched. A new row is inserted if none matches.
        """
        state_json = json.dumps(state)
        tbl = self._quote_ident(table)
        col = self._quote_ident(id_name)
        # ``with self.conn`` commits on success and rolls back on error, so
        # the UPDATE/INSERT pair is atomic and no transaction is left open.
        with self.lock, self.conn:
            # UPDATE-then-INSERT rather than INSERT OR REPLACE: REPLACE
            # deletes the old row, which would reset every other column
            # (chip_id, registers, ...) to NULL.
            cur = self.conn.execute(
                f"UPDATE {tbl} SET state_json = ? WHERE {col} = ?",
                (state_json, id_value),
            )
            if cur.rowcount == 0:
                self.conn.execute(
                    f"INSERT INTO {tbl} ({col}, state_json) VALUES (?, ?)",
                    (id_value, state_json),
                )

    def load_state(self, table, id_name, id_value):
        """Return the decoded ``state_json`` for *id_value* in *table*.

        Returns None when no row matches, or when the matching row has no
        stored state (``state_json`` is NULL).
        """
        sql = (
            f"SELECT state_json FROM {self._quote_ident(table)} "
            f"WHERE {self._quote_ident(id_name)}=?"
        )
        with self.lock:
            row = self.conn.execute(sql, (id_value,)).fetchone()
        if row is None or row[0] is None:
            return None
        return json.loads(row[0])

    def close(self):
        """Close the connection; further calls are no-ops."""
        with self.lock:
            if self.conn is not None:
                self.conn.close()
                self.conn = None