| import sqlite3 |
| import os |
| import time |
| from astrbot.core.db.po import Platform, Stats, LLMHistory, ATRIVision, Conversation |
| from . import BaseDatabase |
| from typing import Tuple |
|
|
|
|
class SQLiteDatabase(BaseDatabase):
    """SQLite-backed implementation of ``BaseDatabase``.

    One long-lived connection is opened in ``__init__`` and reused by every
    method. If that connection refuses to hand out a cursor
    (``sqlite3.ProgrammingError`` -- raised by sqlite3, for example, when a
    connection is used from a thread other than the one that created it), a
    short-lived connection is opened for that single operation and closed
    again once the operation is finished.
    """

    # Columns that __init__ adds to webchat_conversation when an existing
    # database file does not have them yet. Both are created as TEXT.
    _CONVERSATION_MIGRATION_COLUMNS = ("title", "persona_id")

    def __init__(self, db_path: str) -> None:
        """Open the database at *db_path* and make sure the schema is current.

        Runs ``sqlite_init.sql`` (located next to this module) and then adds
        any missing ``webchat_conversation`` columns listed in
        ``_CONVERSATION_MIGRATION_COLUMNS``.
        """
        super().__init__()
        self.db_path = db_path

        # os.path.join keeps the path relative when dirname() is empty instead
        # of silently pointing at the filesystem root.
        init_script = os.path.join(os.path.dirname(__file__), "sqlite_init.sql")
        # Explicit encoding so the script is read the same way regardless of
        # the platform's locale default (assumes the script is UTF-8).
        with open(init_script, "r", encoding="utf-8") as f:
            sql = f.read()

        self.conn = self._get_conn(self.db_path)
        c = self.conn.cursor()
        try:
            c.executescript(sql)
            self.conn.commit()

            # Row index 1 of PRAGMA table_info is the column name.
            c.execute("PRAGMA table_info(webchat_conversation)")
            existing_columns = {row[1] for row in c.fetchall()}

            for column in self._CONVERSATION_MIGRATION_COLUMNS:
                if column not in existing_columns:
                    # Column names come from the class constant above, never
                    # from user input, so interpolation is safe here.
                    c.execute(
                        f"ALTER TABLE webchat_conversation ADD COLUMN {column} TEXT"
                    )
                    self.conn.commit()
        finally:
            c.close()

    def _get_conn(self, db_path: str) -> sqlite3.Connection:
        """Open and return a new connection to the database at *db_path*."""
        conn = sqlite3.connect(db_path)
        conn.text_factory = str
        return conn

    def _open_cursor(self):
        """Return ``(connection, cursor, temporary)`` for one operation.

        ``temporary`` is True when the shared connection could not be used and
        a fresh connection was opened instead; the caller must close it.
        """
        try:
            return self.conn, self.conn.cursor(), False
        except sqlite3.ProgrammingError:
            conn = self._get_conn(self.db_path)
            return conn, conn.cursor(), True

    def _fetch_rows(self, sql: str, params: Tuple = (), one: bool = False):
        """Run a read-only statement and return its rows.

        Args:
            sql: Statement to execute, using ``?`` placeholders.
            params: Values bound to the placeholders.
            one: When True return only the first row (or None), otherwise
                return the list of all rows.
        """
        conn, c, temporary = self._open_cursor()
        try:
            c.execute(sql, params)
            return c.fetchone() if one else c.fetchall()
        finally:
            c.close()
            if temporary:
                conn.close()

    def _exec_sql(self, sql: str, params: Tuple = None):
        """Execute a single write statement and commit it.

        Args:
            sql: Statement to execute, using ``?`` placeholders.
            params: Values bound to the placeholders; None or empty means the
                statement takes no parameters.
        """
        conn, c, temporary = self._open_cursor()
        try:
            try:
                c.execute(sql, params or ())
            finally:
                c.close()
            conn.commit()
        finally:
            # Only connections opened for this call are closed; the shared
            # connection stays open for later calls.
            if temporary:
                conn.close()

    def _insert_metrics(self, table: str, metrics: dict):
        """Insert one ``(name, count, timestamp)`` row per item of *metrics*.

        *table* must be a trusted, hard-coded table name: it is interpolated
        into the statement because SQL placeholders cannot bind identifiers.
        """
        for name, count in metrics.items():
            self._exec_sql(
                f"INSERT INTO {table}(name, count, timestamp) VALUES (?, ?, ?)",
                (name, count, int(time.time())),
            )

    def insert_platform_metrics(self, metrics: dict):
        """Store per-platform counters; *metrics* maps name -> count."""
        self._insert_metrics("platform", metrics)

    def insert_plugin_metrics(self, metrics: dict):
        """Accept plugin metrics; currently a no-op (nothing is stored)."""
        pass

    def insert_command_metrics(self, metrics: dict):
        """Store per-command counters; *metrics* maps name -> count."""
        self._insert_metrics("command", metrics)

    def insert_llm_metrics(self, metrics: dict):
        """Store per-LLM counters; *metrics* maps name -> count."""
        self._insert_metrics("llm", metrics)

    def update_llm_history(self, session_id: str, content: str, provider_type: str):
        """Create or overwrite the history stored for a session/provider pair.

        If a row already exists for (*session_id*, *provider_type*) its content
        is replaced, otherwise a new row is inserted.
        """
        if self.get_llm_history(session_id, provider_type):
            self._exec_sql(
                "UPDATE llm_history SET content = ? "
                "WHERE session_id = ? AND provider_type = ?",
                (content, session_id, provider_type),
            )
        else:
            self._exec_sql(
                "INSERT INTO llm_history(provider_type, session_id, content) "
                "VALUES (?, ?, ?)",
                (provider_type, session_id, content),
            )

    def get_llm_history(
        self, session_id: str = None, provider_type: str = None
    ) -> Tuple:
        """Return ``LLMHistory`` records, optionally filtered.

        Args:
            session_id: When truthy, only rows with this session id.
            provider_type: When truthy, only rows with this provider type.

        Returns:
            A list of ``LLMHistory`` objects (empty when nothing matches).
        """
        conditions = []
        params = []
        if session_id:
            conditions.append("session_id = ?")
            params.append(session_id)
        if provider_type:
            conditions.append("provider_type = ?")
            params.append(provider_type)

        sql = "SELECT * FROM llm_history"
        if conditions:
            # Values are bound as parameters rather than formatted into the
            # statement, so quotes in the inputs cannot alter the query.
            sql += " WHERE " + " AND ".join(conditions)

        rows = self._fetch_rows(sql, tuple(params))
        return [LLMHistory(*row) for row in rows]

    def get_base_stats(self, offset_sec: int = 86400) -> Stats:
        """Get basic statistics from *offset_sec* seconds ago until now.

        Only platform rows are loaded; the other two ``Stats`` fields are
        passed as empty lists.
        """
        since = int(time.time()) - offset_sec
        rows = self._fetch_rows(
            "SELECT * FROM platform WHERE timestamp >= ?",
            (since,),
        )
        return Stats([Platform(*row) for row in rows], [], [])

    def get_total_message_count(self) -> int:
        """Return the sum of ``count`` over all platform rows.

        Returns 0 when the table is empty (SQLite's ``SUM`` yields NULL for an
        empty set, which would otherwise surface as None).
        """
        row = self._fetch_rows("SELECT SUM(count) FROM platform", one=True)
        total = row[0] if row else None
        return total if total is not None else 0

    def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats:
        """Get basic statistics from *offset_sec* seconds ago until now (merged).

        Platform counts are summed per platform name; the other two ``Stats``
        fields are passed as empty lists.
        """
        since = int(time.time()) - offset_sec
        rows = self._fetch_rows(
            "SELECT name, SUM(count), timestamp FROM platform "
            "WHERE timestamp >= ? GROUP BY name",
            (since,),
        )
        return Stats([Platform(*row) for row in rows], [], [])

    def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
        """Return the conversation *cid* of *user_id*, or None if absent."""
        row = self._fetch_rows(
            "SELECT * FROM webchat_conversation WHERE user_id = ? AND cid = ?",
            (user_id, cid),
            one=True,
        )
        if not row:
            return None
        return Conversation(*row)

    def new_conversation(self, user_id: str, cid: str):
        """Insert an empty conversation (history ``"[]"``) stamped with now."""
        now = int(time.time())
        self._exec_sql(
            "INSERT INTO webchat_conversation"
            "(user_id, cid, history, updated_at, created_at) "
            "VALUES (?, ?, ?, ?, ?)",
            (user_id, cid, "[]", now, now),
        )

    def get_conversations(self, user_id: str) -> Tuple:
        """Return the conversations of *user_id*, most recently updated first.

        The history column is not loaded: each returned ``Conversation`` gets
        an empty string as its first constructor argument and ``"[]"`` as its
        history.

        Returns:
            A list of ``Conversation`` objects.
        """
        rows = self._fetch_rows(
            "SELECT cid, created_at, updated_at, title, persona_id "
            "FROM webchat_conversation WHERE user_id = ? "
            "ORDER BY updated_at DESC",
            (user_id,),
        )
        return [
            Conversation("", cid, "[]", created_at, updated_at, title, persona_id)
            for cid, created_at, updated_at, title, persona_id in rows
        ]

    def update_conversation(self, user_id: str, cid: str, history: str):
        """Update the conversation history and refresh its update time."""
        self._exec_sql(
            "UPDATE webchat_conversation SET history = ?, updated_at = ? "
            "WHERE user_id = ? AND cid = ?",
            (history, int(time.time()), user_id, cid),
        )

    def update_conversation_title(self, user_id: str, cid: str, title: str):
        """Set the title of conversation *cid* belonging to *user_id*."""
        self._exec_sql(
            "UPDATE webchat_conversation SET title = ? "
            "WHERE user_id = ? AND cid = ?",
            (title, user_id, cid),
        )

    def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
        """Set the persona id of conversation *cid* belonging to *user_id*."""
        self._exec_sql(
            "UPDATE webchat_conversation SET persona_id = ? "
            "WHERE user_id = ? AND cid = ?",
            (persona_id, user_id, cid),
        )

    def delete_conversation(self, user_id: str, cid: str):
        """Delete conversation *cid* belonging to *user_id*."""
        self._exec_sql(
            "DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ?",
            (user_id, cid),
        )

    def insert_atri_vision_data(self, vision: ATRIVision):
        """Insert *vision* into ``atri_vision``, stamped with the current time.

        ``vision.keywords`` is stored as a single comma-joined string.
        """
        self._exec_sql(
            "INSERT INTO atri_vision"
            "(id, url_or_path, caption, is_meme, keywords, platform_name, "
            "session_id, sender_nickname, timestamp) "
            "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
            (
                vision.id,
                vision.url_or_path,
                vision.caption,
                vision.is_meme,
                ",".join(vision.keywords),
                vision.platform_name,
                vision.session_id,
                vision.sender_nickname,
                int(time.time()),
            ),
        )

    def get_atri_vision_data(self) -> Tuple:
        """Return every ``atri_vision`` row as a list of ``ATRIVision``."""
        rows = self._fetch_rows("SELECT * FROM atri_vision")
        return [ATRIVision(*row) for row in rows]

    def get_atri_vision_data_by_path_or_id(
        self, url_or_path: str, id: str
    ) -> ATRIVision:
        """Return the first row matching *url_or_path* or *id*, else None."""
        row = self._fetch_rows(
            "SELECT * FROM atri_vision WHERE url_or_path = ? OR id = ?",
            (url_or_path, id),
            one=True,
        )
        if row:
            return ATRIVision(*row)
        return None
|
|