"""SQLite connection pool plus a small manager for the ``context_records`` table."""

import os
import queue
import sqlite3
import threading
from contextlib import contextmanager
from datetime import datetime
from typing import Any, Dict, Iterator, Optional


def _current_timestamp() -> str:
    """Return the current local time as ``YYYY-MM-DD HH:MM:SS.ffffff`` text.

    This is the text form the default sqlite3 datetime adapter produced
    (``isoformat(" ")``). Binding the string explicitly keeps the stored
    values unchanged while not relying on that adapter, which is deprecated
    since Python 3.12.
    """
    return datetime.now().isoformat(" ")


class SQLiteConnectionPool:
    """Fixed-size pool of SQLite connections that all point at one database file.

    Connections are opened eagerly in the constructor with
    ``check_same_thread=False`` so that a connection may be checked out by a
    thread other than the one that created the pool. Rows are returned as
    ``sqlite3.Row`` objects.
    """

    def __init__(self, database, max_connections=5):
        """Open ``max_connections`` connections to ``database``.

        Args:
            database: Path of the database file. ``~`` is expanded and the
                parent directory is created if it does not exist.
            max_connections: Number of connections to open; must be >= 1.

        Raises:
            ValueError: If ``max_connections`` is smaller than 1.
            sqlite3.Error: If a connection cannot be opened.
        """
        # A non-positive size would produce an unbounded, empty queue, and the
        # first get_connection() call would then block forever.
        if max_connections < 1:
            raise ValueError("max_connections must be at least 1")

        # Expand ~ to the user's home directory and resolve an absolute path.
        db_path = os.path.abspath(os.path.expanduser(database))

        # Create the parent directory when the path has a directory part.
        db_dir = os.path.dirname(db_path)
        if db_dir:
            os.makedirs(db_dir, exist_ok=True)

        self.database = db_path
        self.max_connections = max_connections
        self.connections = queue.Queue(maxsize=max_connections)
        # Not used by the pool itself; kept because it is a public attribute.
        self.lock = threading.Lock()

        # Fill the pool. If opening fails part-way, close what was opened so
        # no connection is leaked.
        try:
            for _ in range(max_connections):
                conn = sqlite3.connect(self.database, check_same_thread=False)
                # Row factory giving name-based access to columns.
                conn.row_factory = sqlite3.Row
                self.connections.put(conn)
        except Exception:
            self.close_all()
            raise

    @contextmanager
    def get_connection(self) -> Iterator[sqlite3.Connection]:
        """Check a connection out of the pool for the duration of a ``with`` block.

        Blocks until a connection is available. When the block exits, any
        transaction that is still open (for example because the caller raised
        before committing) is rolled back, and the connection is returned to
        the pool.
        """
        connection = self.connections.get()
        try:
            yield connection
        finally:
            try:
                # Never hand the next borrower a half-finished transaction.
                if connection.in_transaction:
                    connection.rollback()
            finally:
                self.connections.put(connection)

    def close_all(self) -> None:
        """Close every connection that is currently idle in the pool."""
        while True:
            try:
                # get_nowait avoids the empty()/get() race, where another
                # thread could drain the queue between the two calls.
                conn = self.connections.get_nowait()
            except queue.Empty:
                break
            conn.close()


# Database access layer
class DatabaseManager:
    """CRUD helpers for the ``context_records`` table, backed by a connection pool."""

    def __init__(self, pool):
        """Store ``pool`` and make sure the schema exists.

        Args:
            pool: Object exposing ``get_connection()`` as a context manager
                that yields a ``sqlite3.Connection``.
        """
        self.pool = pool
        self.create_tables()

    def create_tables(self) -> None:
        """Create the ``context_records`` table if it does not exist yet."""
        with self.pool.get_connection() as conn:
            cursor = conn.cursor()
            # Create the context_records table.
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS context_records (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    api_key TEXT NOT NULL,
                    chat_id TEXT NOT NULL,
                    parent_id TEXT NOT NULL,
                    sha256_hash TEXT NOT NULL,
                    created_at TIMESTAMP,
                    updated_at TIMESTAMP
                )
            ''')
            conn.commit()

    def insert_context_record(self, api_key, chat_id, parent_id, sha256_hash) -> Optional[int]:
        """Insert one row and stamp ``created_at`` with the current time.

        Returns:
            The row id of the inserted record, or ``None`` if SQLite reported
            an error (the error is printed and the transaction rolled back).
        """
        with self.pool.get_connection() as conn:
            cursor = conn.cursor()
            try:
                cursor.execute(
                    'INSERT INTO context_records (api_key, chat_id, parent_id, sha256_hash, created_at) VALUES (?, ?, ?, ?, ?)',
                    (api_key, chat_id, parent_id, sha256_hash, _current_timestamp())
                )
                conn.commit()
                return cursor.lastrowid
            except sqlite3.Error as e:
                print(f"Error inserting context_records: {e}")
                conn.rollback()
                return None

    def update_context_record_by_chat_id(self, api_key, chat_id, parent_id, sha256_hash) -> Optional[int]:
        """Update ``parent_id``/``sha256_hash`` of the rows matching ``api_key`` and ``chat_id``.

        ``updated_at`` is set to the current time on every matched row.

        Returns:
            The number of rows updated (``0`` when nothing matched), or
            ``None`` if SQLite reported an error (the error is printed and
            the transaction rolled back).
        """
        with self.pool.get_connection() as conn:
            cursor = conn.cursor()
            try:
                cursor.execute(
                    'update context_records set parent_id = ?, sha256_hash = ?, updated_at = ? where api_key = ? and chat_id = ?',
                    (parent_id, sha256_hash, _current_timestamp(), api_key, chat_id)
                )
                conn.commit()
                # lastrowid is only meaningful after INSERT/REPLACE; rowcount
                # is the value that describes the outcome of an UPDATE.
                return cursor.rowcount
            except sqlite3.Error as e:
                print(f"Error updating context_records: {e}")
                conn.rollback()
                return None

    def get_context_record_by_sha256_hash(self, sha256_hash) -> Optional[Dict[str, Any]]:
        """Return the first row whose ``sha256_hash`` matches, as a dict.

        Returns:
            A ``dict`` keyed by column name, or ``None`` when no row matches.
        """
        with self.pool.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute('SELECT * FROM context_records WHERE sha256_hash = ?', (sha256_hash,))
            result = cursor.fetchone()
            return dict(result) if result else None


# Usage example
def main(worker_count: int = 0) -> None:
    """Open the demo database, make sure the schema exists, then close the pool.

    Args:
        worker_count: Number of demo threads to run against the pool. Each
            thread inserts, updates and reads back one record. The default of
            ``0`` starts no threads, so the database contents are untouched.
    """
    # Create the connection pool.
    pool = SQLiteConnectionPool('~/tmp/merlin-sqlite.db', max_connections=5)

    try:
        # The constructor creates the table, so no extra create_tables() call
        # is needed. It runs inside the try so the pool is closed on failure.
        db = DatabaseManager(pool)

        # Simulated multi-threaded workload.
        def worker(worker_number):
            api_key = f"demo_key_{worker_number}"
            chat_id = f"chat_{worker_number}"
            first_hash = f"hash_{worker_number}_0"
            second_hash = f"hash_{worker_number}_1"

            # Insert a record, then move it to a new parent/hash.
            record_id = db.insert_context_record(api_key, chat_id, "parent_0", first_hash)
            if record_id is None:
                return
            db.update_context_record_by_chat_id(api_key, chat_id, "parent_1", second_hash)

            # Read the record back through its new hash.
            record = db.get_context_record_by_sha256_hash(second_hash)
            print(f"Record for {chat_id}: {record}")

        # Start the requested number of threads and wait for all of them.
        threads = [
            threading.Thread(target=worker, args=(index + 1,))
            for index in range(worker_count)
        ]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

    finally:
        # Close all pooled connections.
        pool.close_all()


# Batch operation example
def batch_insert_example(db) -> None:
    """Insert three sample rows into ``context_records`` in one transaction.

    Either all rows are committed or, on a SQLite error, the error is printed
    and the whole batch is rolled back.

    Args:
        db: A ``DatabaseManager`` (anything exposing ``pool.get_connection()``).
    """
    with db.pool.get_connection() as conn:
        cursor = conn.cursor()
        try:
            # Prepare the batch data.
            created_at = _current_timestamp()
            records_data = [
                ('batch_key', 'batch_chat_1', 'batch_parent_1', 'batch_hash_1', created_at),
                ('batch_key', 'batch_chat_2', 'batch_parent_2', 'batch_hash_2', created_at),
                ('batch_key', 'batch_chat_3', 'batch_parent_3', 'batch_hash_3', created_at),
            ]

            # Batch insert; sqlite3 opens the transaction implicitly before
            # the first INSERT.
            cursor.executemany(
                'INSERT INTO context_records (api_key, chat_id, parent_id, sha256_hash, created_at) VALUES (?, ?, ?, ?, ?)',
                records_data
            )

            # Commit the transaction.
            conn.commit()
        except sqlite3.Error as e:
            print(f"Error in batch insert: {e}")
            conn.rollback()


if __name__ == "__main__":
    main()