File size: 2,086 Bytes
2ff82ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import sqlite3
import json
import threading

class GPUStateDB:
    """SQLite-backed store for JSON-serialized GPU component state.

    One table is created per component kind (``sm``, ``core``, ``warp``,
    ``thread``, ``tensor_core``). Each table is keyed by an INTEGER PRIMARY
    KEY and keeps the component's state as a JSON document in its
    ``state_json`` column.

    A single connection is shared across threads (``check_same_thread=False``).
    Every use of it is serialized through ``self.lock``.

    Instances may be used as a context manager; the connection is closed on
    exit.
    """

    # CREATE statements for every table, run once at construction time.
    _TABLE_DDL = (
        '''CREATE TABLE IF NOT EXISTS sm (
            sm_id INTEGER PRIMARY KEY,
            chip_id INTEGER,
            state_json TEXT
        )''',
        '''CREATE TABLE IF NOT EXISTS core (
            core_id INTEGER PRIMARY KEY,
            sm_id INTEGER,
            registers BLOB,
            state_json TEXT
        )''',
        '''CREATE TABLE IF NOT EXISTS warp (
            warp_id INTEGER PRIMARY KEY,
            sm_id INTEGER,
            thread_ids TEXT,
            state_json TEXT
        )''',
        '''CREATE TABLE IF NOT EXISTS thread (
            thread_id INTEGER PRIMARY KEY,
            warp_id INTEGER,
            core_id INTEGER,
            state_json TEXT
        )''',
        '''CREATE TABLE IF NOT EXISTS tensor_core (
            tensor_core_id INTEGER PRIMARY KEY,
            sm_id INTEGER,
            memory BLOB,
            state_json TEXT
        )''',
    )

    def __init__(self, db_path='gpu_state.db'):
        """Open (or create) the database at *db_path* and ensure the tables exist.

        Args:
            db_path: Path passed to ``sqlite3.connect`` (``':memory:'`` works).
        """
        # The connection is shared between threads; self.lock serializes all
        # access to it, which is what makes check_same_thread=False safe.
        self.conn = sqlite3.connect(db_path, check_same_thread=False)
        self.lock = threading.Lock()
        try:
            self._init_tables()
        except BaseException:
            # Don't leak the connection if schema creation fails.
            self.conn.close()
            self.conn = None
            raise

    def _init_tables(self):
        """Create the component tables if they do not already exist."""
        # `with self.conn` commits on success and rolls back on error.
        with self.lock, self.conn:
            for ddl in self._TABLE_DDL:
                self.conn.execute(ddl)

    @staticmethod
    def _identifier(name):
        """Return *name* as a string if it is safe to splice into SQL text.

        Table and column names cannot be bound as ``?`` parameters, so they
        are interpolated into the statement. Restricting them to plain
        identifiers prevents arbitrary SQL from being injected through them.

        Raises:
            ValueError: if *name* is not a plain identifier.
        """
        text = str(name)
        if not text.isidentifier():
            raise ValueError(f'invalid SQL identifier: {name!r}')
        return text

    def save_state(self, table, id_name, id_value, state):
        """Store *state* as JSON in the row of *table* where *id_name* == *id_value*.

        The row is created if it does not exist. Only ``state_json`` is
        written: other columns of an existing row (e.g. ``sm.chip_id`` or
        ``core.registers``) are preserved.

        Args:
            table: Name of the table to write.
            id_name: Key column. Expected to be the table's primary key; if it
                is not, every row matching *id_value* is updated.
            id_value: Key value identifying the row.
            state: JSON-serializable object.

        Raises:
            TypeError: if *state* is not JSON-serializable (from ``json.dumps``).
            ValueError: if *table* or *id_name* is not a plain identifier.
        """
        state_json = json.dumps(state)
        table = self._identifier(table)
        id_name = self._identifier(id_name)
        with self.lock, self.conn:
            # UPDATE first instead of INSERT OR REPLACE. REPLACE deletes the
            # conflicting row and inserts a fresh one, which would reset every
            # column not listed here to NULL.
            cur = self.conn.execute(
                f'UPDATE {table} SET state_json=? WHERE {id_name}=?',
                (state_json, id_value),
            )
            if cur.rowcount == 0:
                self.conn.execute(
                    f'INSERT INTO {table} ({id_name}, state_json) VALUES (?, ?)',
                    (id_value, state_json),
                )

    def load_state(self, table, id_name, id_value):
        """Return the decoded state for the row where *id_name* == *id_value*.

        Args:
            table: Name of the table to read.
            id_name: Column to match against *id_value*.
            id_value: Value identifying the row.

        Returns:
            The object decoded from ``state_json``. Returns None if no row
            matches or the matching row has no stored state (NULL
            ``state_json``). If several rows match, the first one returned
            by SQLite is used.

        Raises:
            ValueError: if *table* or *id_name* is not a plain identifier.
        """
        table = self._identifier(table)
        id_name = self._identifier(id_name)
        with self.lock:
            row = self.conn.execute(
                f'SELECT state_json FROM {table} WHERE {id_name}=?', (id_value,)
            ).fetchone()
        # A row may exist with NULL state_json (e.g. inserted by code that only
        # filled the other columns); treat that like a missing state rather
        # than letting json.loads(None) raise TypeError.
        if row is None or row[0] is None:
            return None
        return json.loads(row[0])

    def close(self):
        """Close the database connection. Calling it again is a no-op."""
        # Hold the lock so we never close underneath an in-flight query.
        with self.lock:
            if self.conn is not None:
                self.conn.close()
                self.conn = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False