File size: 5,763 Bytes
3979178 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 |
import os
import sqlite3
import queue
import threading
from datetime import datetime
from contextlib import contextmanager
class SQLiteConnectionPool:
    """Fixed-size pool of SQLite connections that can be shared between threads.

    All connections are opened eagerly in ``__init__`` against the same
    database file and handed out one at a time by :meth:`get_connection`.
    """

    def __init__(self, database, max_connections=5):
        """Open ``max_connections`` connections to ``database``.

        Args:
            database: Path to the SQLite file. ``~`` is expanded, the path is
                made absolute, and missing parent directories are created.
            max_connections: Number of connections to open; must be >= 1.

        Raises:
            ValueError: If ``max_connections`` is less than 1.
            sqlite3.Error: If a connection cannot be opened. Connections that
                were opened before the failure are closed first.
        """
        if max_connections < 1:
            # queue.Queue(maxsize=0) is unbounded and no connection would be
            # created, so get_connection() would block forever.
            raise ValueError(f"max_connections must be >= 1, got {max_connections}")
        # Expand ~ to the user's home directory and make the path absolute.
        db_path = os.path.abspath(os.path.expanduser(database))
        db_dir = os.path.dirname(db_path)
        if db_dir:  # create the containing directory if needed
            os.makedirs(db_dir, exist_ok=True)
        self.database = db_path
        self.max_connections = max_connections
        self.connections = queue.Queue(maxsize=max_connections)
        # Not used by the pool itself (queue.Queue is already thread-safe);
        # kept in case external code references it.
        self.lock = threading.Lock()
        try:
            for _ in range(max_connections):
                # check_same_thread=False: a connection released by one thread
                # may later be borrowed by another.
                conn = sqlite3.connect(self.database, check_same_thread=False)
                # sqlite3.Row rows support both index and column-name access.
                conn.row_factory = sqlite3.Row
                self.connections.put(conn)
        except Exception:
            # Do not leak the connections opened before the failure.
            self.close_all()
            raise

    @contextmanager
    def get_connection(self):
        """Borrow a connection for the duration of a ``with`` block.

        Blocks until a connection is free. When the block exits, any
        transaction the borrower left open (uncommitted or failed) is rolled
        back, so an abandoned transaction cannot leak into the next
        borrower's work.

        Yields:
            sqlite3.Connection: A pooled connection whose ``row_factory`` is
            ``sqlite3.Row``.
        """
        connection = self.connections.get()
        try:
            yield connection
        finally:
            try:
                if connection.in_transaction:
                    connection.rollback()
            finally:
                # Always hand the connection back, even if rollback failed.
                self.connections.put(connection)

    def close_all(self):
        """Close every connection currently idle in the pool.

        Connections that are checked out at the time of the call are not
        closed here.
        """
        while True:
            # get_nowait avoids blocking forever if another thread takes the
            # last idle connection between an empty() check and get().
            try:
                conn = self.connections.get_nowait()
            except queue.Empty:
                break
            conn.close()
# Data-access layer for the context_records table.
class DatabaseManager:
    """Read and write rows of the ``context_records`` table.

    Every operation borrows a connection from ``pool``, which must provide a
    ``get_connection()`` context manager (for example ``SQLiteConnectionPool``).
    """

    def __init__(self, pool):
        """Store ``pool`` and make sure the ``context_records`` table exists."""
        self.pool = pool
        self.create_tables()

    @staticmethod
    def _timestamp():
        """Return the current local time as ``YYYY-MM-DD HH:MM:SS[.ffffff]`` text.

        This is the same text the sqlite3 default datetime adapter stored;
        binding ``datetime`` objects directly relies on that adapter, which
        is deprecated since Python 3.12.
        """
        return datetime.now().isoformat(" ")

    def create_tables(self):
        """Create the ``context_records`` table if it does not exist yet."""
        with self.pool.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute('''
            CREATE TABLE IF NOT EXISTS context_records (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                api_key TEXT NOT NULL,
                chat_id TEXT NOT NULL,
                parent_id TEXT NOT NULL,
                sha256_hash TEXT NOT NULL,
                created_at TIMESTAMP,
                updated_at TIMESTAMP
            )
            ''')
            conn.commit()

    def insert_context_record(self, api_key, chat_id, parent_id, sha256_hash):
        """Insert a new record with ``created_at`` set to now.

        Returns:
            The new row's ``id``, or ``None`` if SQLite raised an error. In
            that case the error is printed and the transaction rolled back.
        """
        with self.pool.get_connection() as conn:
            cursor = conn.cursor()
            try:
                cursor.execute(
                    'INSERT INTO context_records (api_key, chat_id, parent_id, sha256_hash, created_at) VALUES (?, ?, ?, ?, ?)',
                    (api_key, chat_id, parent_id, sha256_hash, self._timestamp())
                )
                conn.commit()
                return cursor.lastrowid
            except sqlite3.Error as e:
                # Do not hand the connection back with a half-open transaction.
                conn.rollback()
                print(f"Error inserting context_records: {e}")
                return None

    def update_context_record_by_chat_id(self, api_key, chat_id, parent_id, sha256_hash):
        """Set ``parent_id``, ``sha256_hash`` and ``updated_at`` on matching rows.

        Every row whose ``api_key`` and ``chat_id`` both match is updated.

        Returns:
            The number of rows updated (``0`` when nothing matched), or
            ``None`` if SQLite raised an error. In that case the error is
            printed and the transaction rolled back.
        """
        with self.pool.get_connection() as conn:
            cursor = conn.cursor()
            try:
                cursor.execute(
                    'update context_records set parent_id = ?, sha256_hash = ?, updated_at = ? where api_key = ? and chat_id = ?',
                    (parent_id, sha256_hash, self._timestamp(), api_key, chat_id)
                )
                conn.commit()
                # lastrowid is not meaningful after an UPDATE; rowcount is.
                return cursor.rowcount
            except sqlite3.Error as e:
                conn.rollback()
                print(f"Error updating context_records: {e}")
                return None

    def get_context_record_by_sha256_hash(self, sha256_hash):
        """Return one record with the given ``sha256_hash`` as a dict, or ``None``.

        No ordering is specified, so if several rows share the hash, which
        one is returned is up to SQLite.
        """
        with self.pool.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute('SELECT * FROM context_records WHERE sha256_hash = ?', (sha256_hash,))
            result = cursor.fetchone()
            return dict(result) if result else None
# Usage example
def main(database='~/tmp/merlin-sqlite.db', demo_threads=0):
    """Open the context-record database and optionally run a threaded demo.

    Args:
        database: Path to the SQLite file (``~`` is expanded by the pool).
        demo_threads: Number of worker threads that each insert, update and
            read back a sample record. The default ``0`` runs no demo, so the
            database and its table are only created.
    """
    pool = SQLiteConnectionPool(database, max_connections=5)
    try:
        # DatabaseManager.__init__ already calls create_tables(). Building it
        # inside the try guarantees the pool is closed if that fails.
        db = DatabaseManager(pool)

        def worker(worker_number):
            """Insert, update and print one demo record."""
            chat_id = f"chat_{worker_number}"
            record_id = db.insert_context_record(
                "demo-key", chat_id, "parent_0", f"demo_hash_{worker_number}_0"
            )
            if record_id is None:
                return
            db.update_context_record_by_chat_id(
                "demo-key", chat_id, "parent_1", f"demo_hash_{worker_number}_1"
            )
            record = db.get_context_record_by_sha256_hash(f"demo_hash_{worker_number}_1")
            print(f"Record for {chat_id}: {record}")

        threads = [
            threading.Thread(target=worker, args=(i + 1,))
            for i in range(demo_threads)
        ]
        for t in threads:
            t.start()
        # Wait for all workers before the pool is closed.
        for t in threads:
            t.join()
    finally:
        # Close all idle connections.
        pool.close_all()
# Batch insert example
def batch_insert_example(db, records=None):
    """Insert several context records in a single transaction.

    Args:
        db: An object with a ``pool`` attribute providing ``get_connection()``
            (for example a ``DatabaseManager``).
        records: Iterable of ``(api_key, chat_id, parent_id, sha256_hash)``
            tuples. Defaults to three demo rows.

    Returns:
        The number of rows inserted, or ``None`` if the batch failed. A
        failed batch is rolled back and the error printed.
    """
    if records is None:
        records = [
            ('demo-key', 'chat_1', 'parent_1', 'demo_hash_1'),
            ('demo-key', 'chat_2', 'parent_2', 'demo_hash_2'),
            ('demo-key', 'chat_3', 'parent_3', 'demo_hash_3'),
        ]
    # Same text format the (Python 3.12-deprecated) default datetime adapter
    # produced, without relying on that adapter.
    created_at = datetime.now().isoformat(" ")
    rows = [(*record, created_at) for record in records]
    with db.pool.get_connection() as conn:
        try:
            # `with conn` commits on success and rolls back on any exception,
            # so all rows are inserted atomically.
            with conn:
                cursor = conn.executemany(
                    'INSERT INTO context_records (api_key, chat_id, parent_id, sha256_hash, created_at) VALUES (?, ?, ?, ?, ?)',
                    rows
                )
        except sqlite3.Error as e:
            print(f"Error in batch insert: {e}")
            return None
        return cursor.rowcount
if __name__ == '__main__':
    # Script entry point: run the usage example only when executed directly,
    # never on import.
    main()
|