import asyncpg
import json
import logging
import time
from typing import Optional, List, Dict, Any

from config import config

logger = logging.getLogger(__name__)


class Database:
    """Async PostgreSQL access layer built on an asyncpg connection pool.

    Call ``connect()`` before using any other method and ``disconnect()``
    on shutdown.
    """

    def __init__(self) -> None:
        # Connection pool; None until connect() succeeds and after disconnect().
        self.pool: Optional[asyncpg.Pool] = None

    @staticmethod
    async def _init_connection(conn: asyncpg.Connection) -> None:
        """Register JSON codecs on every new pooled connection.

        By default asyncpg exchanges json/jsonb values as ``str``. Without
        these codecs ``update_user_settings`` (which passes a dict) fails and
        ``get_user_settings`` returns a JSON string instead of a dict.
        """
        for type_name in ("json", "jsonb"):
            await conn.set_type_codec(
                type_name,
                encoder=json.dumps,
                decoder=json.loads,
                schema="pg_catalog",
            )

    async def connect(self) -> None:
        """Create the connection pool and ensure tables and indexes exist.

        Calling this while already connected is a no-op, so a second pool is
        never created (and the first leaked). If schema setup fails, the new
        pool is closed before the exception propagates.
        """
        if self.pool is not None:
            return
        pool = await asyncpg.create_pool(
            dsn=config.DATABASE_URL,
            min_size=1,
            max_size=5,  # raised for scalability
            command_timeout=60,
            init=self._init_connection,
            server_settings={
                "jit": "off",  # disable JIT for stability
                "application_name": "glm_bot",
            },
        )
        self.pool = pool
        logger.info("Database pool created (max_size=5)")
        try:
            await self._create_tables()
            await self._create_indexes()
        except Exception:
            # Do not leave a half-initialized pool behind.
            self.pool = None
            await pool.close()
            raise

    async def disconnect(self) -> None:
        """Close the pool if open; the instance can be reconnected afterwards."""
        if self.pool is not None:
            # Clear the attribute first so _acquire() reports "not connected"
            # instead of handing out connections from a closed pool.
            pool, self.pool = self.pool, None
            await pool.close()
            logger.info("Database pool closed")

    def _acquire(self):
        """Return the pool's acquire() context manager.

        Raises:
            RuntimeError: if connect() has not been called (or after disconnect()).
        """
        if self.pool is None:
            raise RuntimeError("Database not connected. Call connect() first.")
        return self.pool.acquire()

    async def _create_tables(self) -> None:
        """Create all application tables if they do not already exist."""
        async with self._acquire() as conn:
            # Users table
            await conn.execute("""
                CREATE TABLE IF NOT EXISTS users (
                    id BIGINT PRIMARY KEY,
                    username VARCHAR(255),
                    first_name VARCHAR(255),
                    last_name VARCHAR(255),
                    created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
                    updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
                    settings JSONB DEFAULT '{}'::jsonb
                )
            """)

            # Messages table
            await conn.execute("""
                CREATE TABLE IF NOT EXISTS messages (
                    id SERIAL PRIMARY KEY,
                    user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
                    role VARCHAR(20) NOT NULL CHECK (role IN ('user', 'assistant', 'system')),
                    content TEXT NOT NULL,
                    tokens_used INTEGER DEFAULT 0,
                    is_summarized BOOLEAN DEFAULT FALSE,
                    created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
                )
            """)

            # Summaries table (at most one summary row per user)
            await conn.execute("""
                CREATE TABLE IF NOT EXISTS summaries (
                    id SERIAL PRIMARY KEY,
                    user_id BIGINT NOT NULL UNIQUE REFERENCES users(id) ON DELETE CASCADE,
                    summary TEXT NOT NULL,
                    message_count INTEGER NOT NULL DEFAULT 0,
                    created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
                    updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
                )
            """)

            # Metrics table — for monitoring
            await conn.execute("""
                CREATE TABLE IF NOT EXISTS metrics (
                    id SERIAL PRIMARY KEY,
                    user_id BIGINT REFERENCES users(id) ON DELETE SET NULL,
                    model VARCHAR(100),
                    request_duration_ms FLOAT,
                    tokens_input INTEGER DEFAULT 0,
                    tokens_output INTEGER DEFAULT 0,
                    total_tokens INTEGER DEFAULT 0,
                    success BOOLEAN DEFAULT TRUE,
                    error_type VARCHAR(100),
                    created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
                )
            """)

            # Rate limiting table (one fixed-window counter per user)
            await conn.execute("""
                CREATE TABLE IF NOT EXISTS rate_limits (
                    user_id BIGINT PRIMARY KEY REFERENCES users(id) ON DELETE CASCADE,
                    request_count INTEGER DEFAULT 0,
                    window_start TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
                    updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
                )
            """)

            logger.info("Database tables created/verified")

    async def _create_indexes(self) -> None:
        """Create the secondary indexes used by the message and metrics queries."""
        async with self._acquire() as conn:
            await conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_messages_user_id_created_at
                ON messages(user_id, created_at DESC)
            """)
            await conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_messages_user_id_summarized
                ON messages(user_id, is_summarized, created_at DESC)
            """)
            await conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_metrics_user_created
                ON metrics(user_id, created_at DESC)
            """)
            await conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_metrics_created_at
                ON metrics(created_at DESC)
            """)

    # ═══════════════════════════════════════════════════════════════
    # Users
    # ═══════════════════════════════════════════════════════════════

    async def upsert_user(
        self,
        user_id: int,
        username: Optional[str],
        first_name: Optional[str],
        last_name: Optional[str],
    ) -> None:
        """Insert the user, or refresh their name fields if they already exist."""
        async with self._acquire() as conn:
            await conn.execute("""
                INSERT INTO users (id, username, first_name, last_name)
                VALUES ($1, $2, $3, $4)
                ON CONFLICT (id) DO UPDATE SET
                    username = EXCLUDED.username,
                    first_name = EXCLUDED.first_name,
                    last_name = EXCLUDED.last_name,
                    updated_at = NOW()
            """, user_id, username, first_name, last_name)

    async def get_user_settings(self, user_id: int) -> Dict[str, Any]:
        """Return the user's settings dict, or ``{}`` if unknown user / empty settings."""
        async with self._acquire() as conn:
            row = await conn.fetchrow("""
                SELECT settings FROM users WHERE id = $1
            """, user_id)
            # The jsonb codec registered in _init_connection decodes to a dict.
            return row["settings"] if row and row["settings"] else {}

    async def update_user_settings(self, user_id: int, settings: Dict[str, Any]) -> None:
        """Replace the user's settings; no-op if the user row does not exist."""
        async with self._acquire() as conn:
            await conn.execute("""
                UPDATE users SET settings = $2, updated_at = NOW()
                WHERE id = $1
            """, user_id, settings)

    # ═══════════════════════════════════════════════════════════════
    # Messages
    # ═══════════════════════════════════════════════════════════════

    async def save_message(
        self, user_id: int, role: str, content: str, tokens_used: int = 0
    ) -> None:
        """Append one chat message to the user's history."""
        async with self._acquire() as conn:
            await conn.execute("""
                INSERT INTO messages (user_id, role, content, tokens_used)
                VALUES ($1, $2, $3, $4)
            """, user_id, role, content, tokens_used)

    async def get_messages(self, user_id: int, limit: int = 30) -> List[Dict[str, Any]]:
        """Return the newest ``limit`` unsummarized messages, oldest first."""
        async with self._acquire() as conn:
            # id is a tie-breaker so rows with identical timestamps come back
            # in a deterministic order.
            rows = await conn.fetch("""
                SELECT id, role, content, tokens_used, created_at
                FROM messages
                WHERE user_id = $1 AND is_summarized = FALSE
                ORDER BY created_at DESC, id DESC
                LIMIT $2
            """, user_id, limit)
            return [
                {
                    "id": r["id"],
                    "role": r["role"],
                    "content": r["content"],
                    "tokens_used": r["tokens_used"],
                    "created_at": r["created_at"],
                }
                for r in reversed(rows)
            ]

    async def get_messages_with_token_budget(
        self, user_id: int, max_tokens: int
    ) -> List[Dict[str, Any]]:
        """Return the newest unsummarized messages that fit within the token budget.

        Messages are taken newest-first until adding the next one would exceed
        ``max_tokens``; at least one message is always returned if any exist.
        The result is ordered oldest first. Messages with no recorded token
        count are estimated at two tokens per whitespace-separated word.
        """
        async with self._acquire() as conn:
            rows = await conn.fetch("""
                SELECT id, role, content, tokens_used, created_at
                FROM messages
                WHERE user_id = $1 AND is_summarized = FALSE
                ORDER BY created_at DESC, id DESC
            """, user_id)

        selected: List[Dict[str, Any]] = []
        total_tokens = 0
        for r in rows:
            msg_tokens = r["tokens_used"] or len(r["content"].split()) * 2
            if total_tokens + msg_tokens > max_tokens and selected:
                break
            total_tokens += msg_tokens
            selected.append({
                "id": r["id"],
                "role": r["role"],
                "content": r["content"],
                "tokens_used": r["tokens_used"],
                "created_at": r["created_at"],
            })
        # Collected newest-first; append + reverse avoids O(n^2) insert(0, ...).
        selected.reverse()
        return selected

    # ═══════════════════════════════════════════════════════════════
    # Summaries
    # ═══════════════════════════════════════════════════════════════

    async def get_summary(self, user_id: int) -> Optional[str]:
        """Return the user's stored summary text, or None if there is none."""
        async with self._acquire() as conn:
            row = await conn.fetchrow("""
                SELECT summary FROM summaries WHERE user_id = $1
            """, user_id)
            return row["summary"] if row else None

    async def save_summary(self, user_id: int, summary: str, message_count: int) -> None:
        """Store the user's summary, replacing the text and adding to message_count."""
        async with self._acquire() as conn:
            await conn.execute("""
                INSERT INTO summaries (user_id, summary, message_count, updated_at)
                VALUES ($1, $2, $3, NOW())
                ON CONFLICT (user_id) DO UPDATE SET
                    summary = EXCLUDED.summary,
                    message_count = summaries.message_count + EXCLUDED.message_count,
                    updated_at = NOW()
            """, user_id, summary, message_count)

    async def mark_summarized(self, user_id: int, cutoff_id: int) -> None:
        """Flag all of the user's messages with ``id <= cutoff_id`` as summarized."""
        async with self._acquire() as conn:
            await conn.execute("""
                UPDATE messages SET is_summarized = TRUE
                WHERE user_id = $1 AND id <= $2
            """, user_id, cutoff_id)

    async def get_oldest_unsummarized(self, user_id: int, limit: int) -> List[Dict[str, Any]]:
        """Return up to ``limit`` of the oldest unsummarized messages, oldest first."""
        async with self._acquire() as conn:
            # id tie-breaker keeps the order deterministic and consistent with
            # the id-based cutoff used by mark_summarized.
            rows = await conn.fetch("""
                SELECT id, role, content
                FROM messages
                WHERE user_id = $1 AND is_summarized = FALSE
                ORDER BY created_at ASC, id ASC
                LIMIT $2
            """, user_id, limit)
            return [{"id": r["id"], "role": r["role"], "content": r["content"]} for r in rows]

    async def count_unsummarized(self, user_id: int) -> int:
        """Return how many of the user's messages are not yet summarized."""
        async with self._acquire() as conn:
            val = await conn.fetchval("""
                SELECT COUNT(*) FROM messages
                WHERE user_id = $1 AND is_summarized = FALSE
            """, user_id)
            return val or 0

    # ═══════════════════════════════════════════════════════════════
    # Clear History
    # ═══════════════════════════════════════════════════════════════

    async def clear_history(self, user_id: int) -> int:
        """Delete all of the user's messages and their summary in one transaction.

        Returns:
            Number of deleted message rows (0 if the status tag is unparsable).
        """
        async with self._acquire() as conn:
            async with conn.transaction():
                result = await conn.execute("""
                    DELETE FROM messages WHERE user_id = $1
                """, user_id)
                await conn.execute("""
                    DELETE FROM summaries WHERE user_id = $1
                """, user_id)
        # execute() returns a status tag such as "DELETE 12".
        try:
            count = int(result.split()[-1])
        except (ValueError, IndexError):
            count = 0
        logger.info("Cleared %d messages and summary for user %s", count, user_id)
        return count

    # ═══════════════════════════════════════════════════════════════
    # Stats
    # ═══════════════════════════════════════════════════════════════

    async def get_stats(self, user_id: int) -> Dict[str, Any]:
        """Return global and per-user usage counters and metrics aggregates."""
        async with self._acquire() as conn:
            user_count = await conn.fetchval("SELECT COUNT(*) FROM users")
            msg_count = await conn.fetchval(
                "SELECT COUNT(*) FROM messages WHERE user_id = $1", user_id
            )
            total_msg_count = await conn.fetchval("SELECT COUNT(*) FROM messages")
            summary = await conn.fetchval(
                "SELECT message_count FROM summaries WHERE user_id = $1", user_id
            )
            total_tokens = await conn.fetchval("""
                SELECT COALESCE(SUM(total_tokens), 0) FROM metrics WHERE user_id = $1
            """, user_id)
            avg_latency = await conn.fetchval("""
                SELECT COALESCE(AVG(request_duration_ms), 0) FROM metrics
                WHERE user_id = $1 AND success = TRUE
            """, user_id)

        return {
            "total_users": user_count,
            "user_messages": msg_count,
            "total_messages": total_msg_count,
            "summarized_messages": summary or 0,
            "total_tokens_used": int(total_tokens),
            "avg_latency_ms": round(avg_latency, 1) if avg_latency else 0,
        }

    # ═══════════════════════════════════════════════════════════════
    # Metrics
    # ═══════════════════════════════════════════════════════════════

    async def save_metric(
        self,
        user_id: Optional[int],
        model: str,
        duration_ms: float,
        tokens_input: int = 0,
        tokens_output: int = 0,
        success: bool = True,
        error_type: Optional[str] = None,
    ) -> None:
        """Record one model request; total_tokens is stored as input + output."""
        async with self._acquire() as conn:
            await conn.execute("""
                INSERT INTO metrics
                    (user_id, model, request_duration_ms, tokens_input,
                     tokens_output, total_tokens, success, error_type)
                VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
            """, user_id, model, duration_ms, tokens_input, tokens_output,
                tokens_input + tokens_output, success, error_type)

    async def get_user_metrics(self, user_id: int, limit: int = 50) -> List[Dict[str, Any]]:
        """Return the user's ``limit`` most recent metric rows, newest first."""
        async with self._acquire() as conn:
            rows = await conn.fetch("""
                SELECT model, request_duration_ms, total_tokens, success,
                       error_type, created_at
                FROM metrics
                WHERE user_id = $1
                ORDER BY created_at DESC
                LIMIT $2
            """, user_id, limit)
            return [dict(r) for r in rows]

    # ═══════════════════════════════════════════════════════════════
    # Rate Limiting
    # ═══════════════════════════════════════════════════════════════

    async def check_rate_limit(self, user_id: int) -> tuple[bool, int, float]:
        """Consume one request from the user's fixed one-minute window.

        Returns (allowed, remaining_requests, reset_in_seconds). When rate
        limiting is disabled in config, always returns ``(True, 999, 0.0)``.
        """
        if not config.RATE_LIMIT_ENABLED:
            return True, 999, 0.0

        limit = config.RATE_LIMIT_REQUESTS_PER_MINUTE
        window_duration = 60.0  # seconds (1 minute)

        async with self._acquire() as conn:
            async with conn.transaction():
                # Atomically ensure the row exists AND lock it. The no-op
                # DO UPDATE takes a row lock on conflict, so concurrent first
                # requests cannot race into a unique violation. Elapsed time
                # is computed with the DB clock — the same clock that wrote
                # window_start — so app/DB clock skew cannot distort the
                # window. EXTRACT returns numeric on newer PostgreSQL, hence
                # the float8 cast.
                row = await conn.fetchrow("""
                    INSERT INTO rate_limits (user_id, request_count, window_start)
                    VALUES ($1, 0, NOW())
                    ON CONFLICT (user_id) DO UPDATE SET user_id = EXCLUDED.user_id
                    RETURNING request_count,
                              EXTRACT(EPOCH FROM (NOW() - window_start))::float8 AS elapsed
                """, user_id)

                # NOW() is the transaction start time, so a row committed by a
                # concurrent transaction can look very slightly "in the future".
                elapsed = max(0.0, row["elapsed"])
                count = row["request_count"] or 0

                if elapsed >= window_duration:
                    # Window expired, reset
                    await conn.execute("""
                        UPDATE rate_limits
                        SET request_count = 1, window_start = NOW(), updated_at = NOW()
                        WHERE user_id = $1
                    """, user_id)
                    return True, limit - 1, window_duration

                reset_in = window_duration - elapsed
                if count >= limit:
                    return False, 0, reset_in

                await conn.execute("""
                    UPDATE rate_limits
                    SET request_count = request_count + 1, updated_at = NOW()
                    WHERE user_id = $1
                """, user_id)
                return True, limit - count - 1, reset_in

    async def reset_rate_limit(self, user_id: int) -> None:
        """Delete the user's rate-limit counter so the next request starts fresh."""
        async with self._acquire() as conn:
            await conn.execute("""
                DELETE FROM rate_limits WHERE user_id = $1
            """, user_id)


# Module-level shared instance; connect() must be awaited before use.
db = Database()