File size: 5,763 Bytes
3979178
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import os
import sqlite3
import queue
import threading
from datetime import datetime
from contextlib import contextmanager

class SQLiteConnectionPool:
    """A fixed-size pool of SQLite connections that can be shared by threads.

    All connections are opened eagerly in ``__init__`` with
    ``check_same_thread=False`` (so any thread may use them) and with
    ``sqlite3.Row`` as the row factory (rows support column-name access and
    can be turned into dicts with ``dict(row)``).

    Args:
        database: Path to the database file. ``~`` is expanded, the path is
            made absolute, and missing parent directories are created.
        max_connections: Number of connections opened and kept in the pool.
    """

    def __init__(self, database, max_connections=5):
        # Expand ``~`` to the user's home directory and make the path absolute.
        db_path = os.path.abspath(os.path.expanduser(database))
        # Make sure the directory that will contain the database file exists.
        db_dir = os.path.dirname(db_path)
        if db_dir:
            os.makedirs(db_dir, exist_ok=True)

        self.database = db_path
        self.max_connections = max_connections
        self.connections = queue.Queue(maxsize=max_connections)
        # Not used by the methods of this class; kept as a public attribute.
        self.lock = threading.Lock()

        # Open every connection up front. If one of them fails, close the
        # ones already opened so they are not leaked, then re-raise.
        try:
            for _ in range(max_connections):
                conn = sqlite3.connect(self.database, check_same_thread=False)
                # sqlite3.Row gives mapping-style (column-name) access to rows.
                conn.row_factory = sqlite3.Row
                self.connections.put(conn)
        except Exception:
            self.close_all()
            raise

    @contextmanager
    def get_connection(self, timeout=None):
        """Borrow a connection for the duration of a ``with`` block.

        Args:
            timeout: Seconds to wait for a free connection. ``None`` (the
                default) waits indefinitely; if the timeout elapses,
                ``queue.Empty`` is raised.

        The connection is always put back into the pool when the block exits.
        If the block raises, any transaction still open on the connection is
        rolled back first, so the next borrower does not inherit uncommitted
        changes (or the write lock they hold).
        """
        connection = self.connections.get(timeout=timeout)
        try:
            yield connection
        except BaseException:
            try:
                if connection.in_transaction:
                    connection.rollback()
            except sqlite3.Error:
                # Best effort: the original exception re-raised below is the
                # one the caller needs to see.
                pass
            raise
        finally:
            self.connections.put(connection)

    def close_all(self):
        """Close every connection currently idle in the pool.

        Connections that are checked out when this is called are not closed
        by this method.
        """
        # get_nowait() avoids the race of empty()-then-get(), which can block
        # forever if another thread takes the last idle connection in between.
        while True:
            try:
                conn = self.connections.get_nowait()
            except queue.Empty:
                break
            conn.close()

# Database access layer for the context_records table.
class DatabaseManager:
    """Data-access helpers for the ``context_records`` table.

    Args:
        pool: An object whose ``get_connection()`` is a context manager that
            yields a ``sqlite3.Connection`` (``SQLiteConnectionPool`` in this
            file). ``get_context_record_by_sha256_hash`` additionally expects
            the connection's ``row_factory`` to be ``sqlite3.Row``, which that
            pool sets.

    The table is created (if missing) during construction.
    """

    def __init__(self, pool):
        self.pool = pool
        self.create_tables()

    @staticmethod
    def _timestamp():
        """Return the current local time as ISO-8601 text.

        ``isoformat(" ")`` is exactly the value sqlite3's default datetime
        adapter stored; formatting it explicitly avoids that adapter, which is
        deprecated since Python 3.12.
        """
        return datetime.now().isoformat(" ")

    def create_tables(self):
        """Create the ``context_records`` table if it does not exist yet."""
        with self.pool.get_connection() as conn:
            cursor = conn.cursor()

            # Create the context_records table.
            cursor.execute('''
            CREATE TABLE IF NOT EXISTS context_records (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                api_key TEXT NOT NULL,
                chat_id TEXT NOT NULL,
                parent_id TEXT NOT NULL,
                sha256_hash TEXT NOT NULL,
                created_at TIMESTAMP,
                updated_at TIMESTAMP
            )
            ''')

            conn.commit()

    def insert_context_record(self, api_key, chat_id, parent_id, sha256_hash):
        """Insert a new record with ``created_at`` set to the current time.

        Returns:
            The rowid of the inserted row, or ``None`` if SQLite reported an
            error (the error is printed and the transaction rolled back).
        """
        with self.pool.get_connection() as conn:
            cursor = conn.cursor()
            try:
                cursor.execute(
                    'INSERT INTO context_records (api_key, chat_id, parent_id, sha256_hash, created_at) VALUES (?, ?, ?, ?, ?)',
                    (api_key, chat_id, parent_id, sha256_hash, self._timestamp())
                )
                conn.commit()
                return cursor.lastrowid
            except sqlite3.Error as e:
                print(f"Error inserting context_records: {e}")
                # Do not hand the pooled connection back mid-transaction.
                conn.rollback()
                return None

    def update_context_record_by_chat_id(self, api_key, chat_id, parent_id, sha256_hash):
        """Update ``parent_id``/``sha256_hash``/``updated_at`` of matching records.

        Every row whose ``api_key`` and ``chat_id`` both match is updated.

        Returns:
            The number of rows updated (``0`` when nothing matches), or
            ``None`` if SQLite reported an error (the error is printed and the
            transaction rolled back).
        """
        with self.pool.get_connection() as conn:
            cursor = conn.cursor()
            try:
                cursor.execute(
                    'update context_records set parent_id = ?, sha256_hash = ?, updated_at = ? where api_key = ? and chat_id = ?',
                    (parent_id, sha256_hash, self._timestamp(), api_key, chat_id)
                )
                conn.commit()
                # lastrowid only reflects INSERTs; after an UPDATE it is a
                # stale value, so report how many rows actually changed.
                return cursor.rowcount
            except sqlite3.Error as e:
                print(f"Error updating context_records: {e}")
                # Do not hand the pooled connection back mid-transaction.
                conn.rollback()
                return None

    def get_context_record_by_sha256_hash(self, sha256_hash):
        """Return one record with the given ``sha256_hash`` as a dict, or ``None``.

        The schema has no uniqueness constraint on ``sha256_hash``; if several
        rows match, the first row SQLite returns is used.
        """
        with self.pool.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute('SELECT * FROM context_records WHERE sha256_hash = ?', (sha256_hash,))
            result = cursor.fetchone()
            return dict(result) if result else None

# Usage example.
def main():
    """Create the context_records schema in ``~/tmp/merlin-sqlite.db``.

    Opens a pool of 5 connections, lets ``DatabaseManager`` create the table
    if it is missing, and always closes the pool afterwards, even if schema
    creation fails.
    """
    pool = SQLiteConnectionPool('~/tmp/merlin-sqlite.db', max_connections=5)

    try:
        # DatabaseManager creates the table during construction, so a
        # separate create_tables() call is not needed.
        DatabaseManager(pool)
    finally:
        # Close all idle pooled connections.
        pool.close_all()

# Batch-insert example.
def batch_insert_example(db):
    """Insert three sample users in a single transaction.

    Borrows a connection via ``db.pool.get_connection()`` and inserts all rows
    with one ``executemany`` call inside an explicit transaction. On any
    ``sqlite3.Error`` the error is printed and the transaction rolled back,
    so either all rows are stored or none are.

    NOTE(review): the statement targets a ``users`` table with ``username``,
    ``email`` and ``created_at`` columns. ``DatabaseManager.create_tables`` in
    this file only creates ``context_records``, so ``users`` must be created
    elsewhere for this example to succeed.
    """
    # Timestamps are bound as ISO-8601 text, exactly what sqlite3's default
    # datetime adapter produced; formatting explicitly avoids that adapter,
    # which is deprecated since Python 3.12.
    users_data = [
        ('user1', 'user1@example.com', datetime.now().isoformat(' ')),
        ('user2', 'user2@example.com', datetime.now().isoformat(' ')),
        ('user3', 'user3@example.com', datetime.now().isoformat(' ')),
    ]

    with db.pool.get_connection() as conn:
        cursor = conn.cursor()
        try:
            # Explicit transaction so all inserts commit (or roll back) together.
            cursor.execute('BEGIN TRANSACTION')

            cursor.executemany(
                'INSERT INTO users (username, email, created_at) VALUES (?, ?, ?)',
                users_data
            )

            conn.commit()
        except sqlite3.Error as e:
            print(f"Error in batch insert: {e}")
            conn.rollback()

# Run the usage example only when this file is executed as a script.
if __name__ == "__main__":
    main()