ml / core /sqlite_store.py
devin15's picture
Upload 31 files
3979178 verified
import os
import sqlite3
import queue
import threading
from datetime import datetime
from contextlib import contextmanager
class SQLiteConnectionPool:
    """A fixed-size pool of SQLite connections shared between threads.

    All connections are opened eagerly in ``__init__`` with
    ``check_same_thread=False`` so any thread may use them, and with
    ``sqlite3.Row`` as row factory so rows can be accessed by column name
    (and converted with ``dict(row)``).
    """

    def __init__(self, database, max_connections=5):
        """Open ``max_connections`` connections to ``database``.

        Args:
            database: Path of the SQLite database file. ``~`` is expanded and
                the path is made absolute; missing parent directories are
                created.
            max_connections: Number of connections to open; must be >= 1.

        Raises:
            ValueError: If ``max_connections`` is less than 1 (the pool would
                be empty and ``get_connection`` would block forever).
            sqlite3.Error: If a connection cannot be opened. Connections that
                were already opened are closed before the error propagates.
        """
        if max_connections < 1:
            raise ValueError(
                f"max_connections must be >= 1, got {max_connections!r}"
            )
        # Expand ``~`` to the user's home directory and make the path absolute.
        db_path = os.path.abspath(os.path.expanduser(database))
        # Create the containing directory if the path has one.
        db_dir = os.path.dirname(db_path)
        if db_dir:
            os.makedirs(db_dir, exist_ok=True)
        self.database = db_path
        self.max_connections = max_connections
        self.connections = queue.Queue(maxsize=max_connections)
        # Not used by the pool itself (queue.Queue is already thread-safe);
        # kept as a public attribute for existing callers.
        self.lock = threading.Lock()
        try:
            for _ in range(max_connections):
                conn = sqlite3.connect(self.database, check_same_thread=False)
                # Rows support access by column name and ``dict(row)``.
                conn.row_factory = sqlite3.Row
                self.connections.put(conn)
        except BaseException:
            # Don't leak the connections opened before the failure.
            self.close_all()
            raise

    @contextmanager
    def get_connection(self):
        """Borrow a connection for the duration of a ``with`` block.

        Blocks until a connection is free. On exit (normal or via an
        exception), any transaction the caller left open is rolled back
        before the connection goes back to the pool. This keeps uncommitted
        work, and any lock that transaction holds, from leaking to the next
        borrower. Callers must ``commit()`` changes they want to keep.
        """
        connection = self.connections.get()
        try:
            yield connection
        finally:
            try:
                if connection.in_transaction:
                    connection.rollback()
            finally:
                # Always return the connection, even if rollback failed.
                self.connections.put(connection)

    def close_all(self):
        """Close every connection currently idle in the pool.

        Connections checked out at the time of the call are not closed.
        ``get_nowait`` is used instead of ``empty()`` + ``get()`` so a
        concurrent borrow between the check and the dequeue cannot make this
        call block forever.
        """
        while True:
            try:
                conn = self.connections.get_nowait()
            except queue.Empty:
                break
            conn.close()
# Database access layer.
class DatabaseManager:
    """Data-access helper for the ``context_records`` table.

    Every method borrows one connection from ``pool`` (any object exposing a
    ``get_connection()`` context manager, e.g. ``SQLiteConnectionPool``) for
    the duration of a single operation.
    """

    def __init__(self, pool):
        """Store the connection pool and make sure the schema exists."""
        self.pool = pool
        self.create_tables()

    @staticmethod
    def _timestamp():
        """Return the current local time as text for a TIMESTAMP column.

        Uses ``isoformat(" ")``, the same text sqlite3's implicit datetime
        adapter produced. That adapter is deprecated since Python 3.12, and
        formatting explicitly leaves the stored values unchanged.
        """
        return datetime.now().isoformat(" ")

    def create_tables(self):
        """Create the ``context_records`` table if it does not exist yet."""
        with self.pool.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS context_records (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    api_key TEXT NOT NULL,
                    chat_id TEXT NOT NULL,
                    parent_id TEXT NOT NULL,
                    sha256_hash TEXT NOT NULL,
                    created_at TIMESTAMP,
                    updated_at TIMESTAMP
                )
            ''')
            conn.commit()

    def insert_context_record(self, api_key, chat_id, parent_id, sha256_hash):
        """Insert a context record with ``created_at`` set to now.

        Returns:
            The ``id`` of the new row, or ``None`` if the insert failed. On
            failure the error is printed and the transaction is rolled back.
        """
        with self.pool.get_connection() as conn:
            cursor = conn.cursor()
            try:
                cursor.execute(
                    'INSERT INTO context_records (api_key, chat_id, parent_id, sha256_hash, created_at) VALUES (?, ?, ?, ?, ?)',
                    (api_key, chat_id, parent_id, sha256_hash, self._timestamp())
                )
                conn.commit()
                return cursor.lastrowid
            except sqlite3.Error as e:
                # Don't hand the connection back with a half-open transaction.
                conn.rollback()
                print(f"Error inserting context_records: {e}")
                return None

    def update_context_record_by_chat_id(self, api_key, chat_id, parent_id, sha256_hash):
        """Set ``parent_id``, ``sha256_hash`` and ``updated_at`` for a chat.

        Every row matching both ``api_key`` and ``chat_id`` is updated.

        Returns:
            The number of rows updated (0 if nothing matched), or ``None`` if
            the update failed. On failure the error is printed and the
            transaction is rolled back.
        """
        with self.pool.get_connection() as conn:
            cursor = conn.cursor()
            try:
                cursor.execute(
                    'update context_records set parent_id = ?, sha256_hash = ?, updated_at = ? where api_key = ? and chat_id = ?',
                    (parent_id, sha256_hash, self._timestamp(), api_key, chat_id)
                )
                conn.commit()
                # lastrowid is not meaningful after an UPDATE; report how
                # many rows actually changed instead.
                return cursor.rowcount
            except sqlite3.Error as e:
                conn.rollback()
                print(f"Error updating context_records: {e}")
                return None

    def get_context_record_by_sha256_hash(self, sha256_hash):
        """Return the first record whose ``sha256_hash`` matches, as a dict.

        Returns ``None`` when no record matches. If several rows share the
        hash, the one returned is whichever SQLite yields first (no ORDER BY).
        """
        with self.pool.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute('SELECT * FROM context_records WHERE sha256_hash = ?', (sha256_hash,))
            result = cursor.fetchone()
            return dict(result) if result else None
# Usage example.
def main(database='~/tmp/merlin-sqlite.db', num_workers=0):
    """Demonstrate ``SQLiteConnectionPool`` together with ``DatabaseManager``.

    Opens a five-connection pool on ``database``; constructing the
    ``DatabaseManager`` creates the ``context_records`` table if needed.
    When ``num_workers`` > 0, that many threads each insert, update and read
    back one demo record through the shared pool. The default of 0 only
    ensures the schema exists, so running the module as a script writes no
    demo rows. The pool is always closed before returning.
    """
    pool = SQLiteConnectionPool(database, max_connections=5)
    try:
        # DatabaseManager.__init__ already creates the table; building it
        # inside ``try`` guarantees the pool is closed even if that fails.
        db = DatabaseManager(pool)

        def worker(worker_number):
            """Insert, update and read back one demo context record."""
            api_key = f"demo_key_{worker_number}"
            chat_id = f"demo_chat_{worker_number}"
            first_hash = f"demo_hash_{worker_number}_v1"
            second_hash = f"demo_hash_{worker_number}_v2"
            if db.insert_context_record(api_key, chat_id, "root", first_hash) is None:
                return
            db.update_context_record_by_chat_id(
                api_key, chat_id, f"demo_parent_{worker_number}", second_hash
            )
            record = db.get_context_record_by_sha256_hash(second_hash)
            print(f"Record for {chat_id}: {record}")

        threads = [
            threading.Thread(target=worker, args=(i + 1,))
            for i in range(num_workers)
        ]
        for thread in threads:
            thread.start()
        # Wait for every worker before the pool is closed below.
        for thread in threads:
            thread.join()
    finally:
        # Close every pooled connection.
        pool.close_all()
# Batch operation example.
def batch_insert_example(db, records=None):
    """Insert several context records in one explicit transaction.

    Args:
        db: A ``DatabaseManager``; only its ``pool.get_connection()`` context
            manager is used.
        records: Iterable of ``(api_key, chat_id, parent_id, sha256_hash)``
            tuples. Defaults to three demo records.

    Returns:
        The number of rows inserted, or ``None`` if the batch failed. On
        failure the error is printed and the whole batch is rolled back, so
        either all rows are stored or none are.
    """
    if records is None:
        records = [
            ('demo_key', 'demo_chat_1', 'root', 'demo_hash_1'),
            ('demo_key', 'demo_chat_2', 'root', 'demo_hash_2'),
            ('demo_key', 'demo_chat_3', 'root', 'demo_hash_3'),
        ]
    # Same text sqlite3's (deprecated since 3.12) datetime adapter produced.
    created_at = datetime.now().isoformat(" ")
    rows = [(*record, created_at) for record in records]
    with db.pool.get_connection() as conn:
        cursor = conn.cursor()
        try:
            # One explicit transaction for the whole batch.
            cursor.execute('BEGIN TRANSACTION')
            cursor.executemany(
                'INSERT INTO context_records (api_key, chat_id, parent_id, sha256_hash, created_at) VALUES (?, ?, ?, ?, ?)',
                rows
            )
            conn.commit()
            # Each INSERT either adds exactly one row or raises.
            return len(rows)
        except sqlite3.Error as e:
            print(f"Error in batch insert: {e}")
            conn.rollback()
            return None
if __name__ == "__main__":
    # Executed only when the module is run as a script, not on import.
    main()