Spaces:
Sleeping
Sleeping
| import asyncpg | |
| import logging | |
| import time | |
| from typing import Optional, List, Dict, Any | |
| from config import config | |
| logger = logging.getLogger(__name__) | |
| class Database: | |
| def __init__(self) -> None: | |
| self.pool: Optional[asyncpg.Pool] = None | |
| async def connect(self) -> None: | |
| self.pool = await asyncpg.create_pool( | |
| dsn=config.DATABASE_URL, | |
| min_size=1, | |
| max_size=5, # ΡΠ²Π΅Π»ΠΈΡΠ΅Π½ΠΎ Π΄Π»Ρ ΠΌΠ°ΡΡΡΠ°Π±ΠΈΡΡΠ΅ΠΌΠΎΡΡΠΈ | |
| command_timeout=60, | |
| server_settings={ | |
| "jit": "off", # ΠΎΡΠΊΠ»ΡΡΠ°Π΅ΠΌ JIT Π΄Π»Ρ ΡΡΠ°Π±ΠΈΠ»ΡΠ½ΠΎΡΡΠΈ | |
| "application_name": "glm_bot", | |
| }, | |
| ) | |
| logger.info("Database pool created (max_size=5)") | |
| await self._create_tables() | |
| await self._create_indexes() | |
| async def disconnect(self) -> None: | |
| if self.pool: | |
| await self.pool.close() | |
| logger.info("Database pool closed") | |
| def _acquire(self): | |
| if self.pool is None: | |
| raise RuntimeError("Database not connected. Call connect() first.") | |
| return self.pool.acquire() | |
| async def _create_tables(self) -> None: | |
| async with self._acquire() as conn: | |
| # Users table | |
| await conn.execute(""" | |
| CREATE TABLE IF NOT EXISTS users ( | |
| id BIGINT PRIMARY KEY, | |
| username VARCHAR(255), | |
| first_name VARCHAR(255), | |
| last_name VARCHAR(255), | |
| created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), | |
| updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), | |
| settings JSONB DEFAULT '{}'::jsonb | |
| ) | |
| """) | |
| # Messages table | |
| await conn.execute(""" | |
| CREATE TABLE IF NOT EXISTS messages ( | |
| id SERIAL PRIMARY KEY, | |
| user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE, | |
| role VARCHAR(20) NOT NULL CHECK (role IN ('user', 'assistant', 'system')), | |
| content TEXT NOT NULL, | |
| tokens_used INTEGER DEFAULT 0, | |
| is_summarized BOOLEAN DEFAULT FALSE, | |
| created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() | |
| ) | |
| """) | |
| # Summaries table | |
| await conn.execute(""" | |
| CREATE TABLE IF NOT EXISTS summaries ( | |
| id SERIAL PRIMARY KEY, | |
| user_id BIGINT NOT NULL UNIQUE REFERENCES users(id) ON DELETE CASCADE, | |
| summary TEXT NOT NULL, | |
| message_count INTEGER NOT NULL DEFAULT 0, | |
| created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), | |
| updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() | |
| ) | |
| """) | |
| # Metrics table β Π΄Π»Ρ ΠΌΠΎΠ½ΠΈΡΠΎΡΠΈΠ½Π³Π° | |
| await conn.execute(""" | |
| CREATE TABLE IF NOT EXISTS metrics ( | |
| id SERIAL PRIMARY KEY, | |
| user_id BIGINT REFERENCES users(id) ON DELETE SET NULL, | |
| model VARCHAR(100), | |
| request_duration_ms FLOAT, | |
| tokens_input INTEGER DEFAULT 0, | |
| tokens_output INTEGER DEFAULT 0, | |
| total_tokens INTEGER DEFAULT 0, | |
| success BOOLEAN DEFAULT TRUE, | |
| error_type VARCHAR(100), | |
| created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() | |
| ) | |
| """) | |
| # Rate limiting table | |
| await conn.execute(""" | |
| CREATE TABLE IF NOT EXISTS rate_limits ( | |
| user_id BIGINT PRIMARY KEY REFERENCES users(id) ON DELETE CASCADE, | |
| request_count INTEGER DEFAULT 0, | |
| window_start TIMESTAMP WITH TIME ZONE DEFAULT NOW(), | |
| updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() | |
| ) | |
| """) | |
| logger.info("Database tables created/verified") | |
| async def _create_indexes(self) -> None: | |
| async with self._acquire() as conn: | |
| await conn.execute(""" | |
| CREATE INDEX IF NOT EXISTS idx_messages_user_id_created_at | |
| ON messages(user_id, created_at DESC) | |
| """) | |
| await conn.execute(""" | |
| CREATE INDEX IF NOT EXISTS idx_messages_user_id_summarized | |
| ON messages(user_id, is_summarized, created_at DESC) | |
| """) | |
| await conn.execute(""" | |
| CREATE INDEX IF NOT EXISTS idx_metrics_user_created | |
| ON metrics(user_id, created_at DESC) | |
| """) | |
| await conn.execute(""" | |
| CREATE INDEX IF NOT EXISTS idx_metrics_created_at | |
| ON metrics(created_at DESC) | |
| """) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Users | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| async def upsert_user( | |
| self, | |
| user_id: int, | |
| username: Optional[str], | |
| first_name: Optional[str], | |
| last_name: Optional[str], | |
| ) -> None: | |
| async with self._acquire() as conn: | |
| await conn.execute(""" | |
| INSERT INTO users (id, username, first_name, last_name) | |
| VALUES ($1, $2, $3, $4) | |
| ON CONFLICT (id) DO UPDATE SET | |
| username = EXCLUDED.username, | |
| first_name = EXCLUDED.first_name, | |
| last_name = EXCLUDED.last_name, | |
| updated_at = NOW() | |
| """, user_id, username, first_name, last_name) | |
| async def get_user_settings(self, user_id: int) -> Dict[str, Any]: | |
| async with self._acquire() as conn: | |
| row = await conn.fetchrow(""" | |
| SELECT settings FROM users WHERE id = $1 | |
| """, user_id) | |
| return row["settings"] if row and row["settings"] else {} | |
| async def update_user_settings(self, user_id: int, settings: Dict[str, Any]) -> None: | |
| async with self._acquire() as conn: | |
| await conn.execute(""" | |
| UPDATE users SET settings = $2, updated_at = NOW() WHERE id = $1 | |
| """, user_id, settings) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Messages | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| async def save_message( | |
| self, user_id: int, role: str, content: str, tokens_used: int = 0 | |
| ) -> None: | |
| async with self._acquire() as conn: | |
| await conn.execute(""" | |
| INSERT INTO messages (user_id, role, content, tokens_used) | |
| VALUES ($1, $2, $3, $4) | |
| """, user_id, role, content, tokens_used) | |
| async def get_messages(self, user_id: int, limit: int = 30) -> List[Dict[str, Any]]: | |
| async with self._acquire() as conn: | |
| rows = await conn.fetch(""" | |
| SELECT id, role, content, tokens_used, created_at | |
| FROM messages | |
| WHERE user_id = $1 AND is_summarized = FALSE | |
| ORDER BY created_at DESC | |
| LIMIT $2 | |
| """, user_id, limit) | |
| return [ | |
| { | |
| "id": r["id"], | |
| "role": r["role"], | |
| "content": r["content"], | |
| "tokens_used": r["tokens_used"], | |
| "created_at": r["created_at"], | |
| } | |
| for r in reversed(rows) | |
| ] | |
| async def get_messages_with_token_budget( | |
| self, user_id: int, max_tokens: int | |
| ) -> List[Dict[str, Any]]: | |
| """ΠΠΎΠ·Π²ΡΠ°ΡΠ°Π΅Ρ ΡΠΎΠΎΠ±ΡΠ΅Π½ΠΈΡ, ΡΠΊΠ»Π°Π΄ΡΠ²Π°ΡΡΠΈΠ΅ΡΡ Π² Π±ΡΠ΄ΠΆΠ΅Ρ ΡΠΎΠΊΠ΅Π½ΠΎΠ².""" | |
| async with self._acquire() as conn: | |
| rows = await conn.fetch(""" | |
| SELECT id, role, content, tokens_used, created_at | |
| FROM messages | |
| WHERE user_id = $1 AND is_summarized = FALSE | |
| ORDER BY created_at DESC | |
| """, user_id) | |
| result = [] | |
| total_tokens = 0 | |
| for r in rows: | |
| msg_tokens = r["tokens_used"] or len(r["content"].split()) * 2 | |
| if total_tokens + msg_tokens > max_tokens and result: | |
| break | |
| total_tokens += msg_tokens | |
| result.insert(0, { | |
| "id": r["id"], | |
| "role": r["role"], | |
| "content": r["content"], | |
| "tokens_used": r["tokens_used"], | |
| "created_at": r["created_at"], | |
| }) | |
| return result | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Summaries | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| async def get_summary(self, user_id: int) -> Optional[str]: | |
| async with self._acquire() as conn: | |
| row = await conn.fetchrow(""" | |
| SELECT summary FROM summaries WHERE user_id = $1 | |
| """, user_id) | |
| return row["summary"] if row else None | |
| async def save_summary(self, user_id: int, summary: str, message_count: int) -> None: | |
| async with self._acquire() as conn: | |
| await conn.execute(""" | |
| INSERT INTO summaries (user_id, summary, message_count, updated_at) | |
| VALUES ($1, $2, $3, NOW()) | |
| ON CONFLICT (user_id) DO UPDATE SET | |
| summary = EXCLUDED.summary, | |
| message_count = summaries.message_count + EXCLUDED.message_count, | |
| updated_at = NOW() | |
| """, user_id, summary, message_count) | |
| async def mark_summarized(self, user_id: int, cutoff_id: int) -> None: | |
| async with self._acquire() as conn: | |
| await conn.execute(""" | |
| UPDATE messages | |
| SET is_summarized = TRUE | |
| WHERE user_id = $1 AND id <= $2 | |
| """, user_id, cutoff_id) | |
| async def get_oldest_unsummarized(self, user_id: int, limit: int) -> List[Dict[str, Any]]: | |
| async with self._acquire() as conn: | |
| rows = await conn.fetch(""" | |
| SELECT id, role, content | |
| FROM messages | |
| WHERE user_id = $1 AND is_summarized = FALSE | |
| ORDER BY created_at ASC | |
| LIMIT $2 | |
| """, user_id, limit) | |
| return [{"id": r["id"], "role": r["role"], "content": r["content"]} for r in rows] | |
| async def count_unsummarized(self, user_id: int) -> int: | |
| async with self._acquire() as conn: | |
| val = await conn.fetchval(""" | |
| SELECT COUNT(*) FROM messages | |
| WHERE user_id = $1 AND is_summarized = FALSE | |
| """, user_id) | |
| return val or 0 | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Clear History | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| async def clear_history(self, user_id: int) -> int: | |
| async with self._acquire() as conn: | |
| async with conn.transaction(): | |
| result = await conn.execute(""" | |
| DELETE FROM messages WHERE user_id = $1 | |
| """, user_id) | |
| await conn.execute(""" | |
| DELETE FROM summaries WHERE user_id = $1 | |
| """, user_id) | |
| try: | |
| count = int(result.split()[-1]) | |
| except (ValueError, IndexError): | |
| count = 0 | |
| logger.info("Cleared %d messages and summary for user %s", count, user_id) | |
| return count | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Stats | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| async def get_stats(self, user_id: int) -> Dict[str, Any]: | |
| async with self._acquire() as conn: | |
| user_count = await conn.fetchval("SELECT COUNT(*) FROM users") | |
| msg_count = await conn.fetchval( | |
| "SELECT COUNT(*) FROM messages WHERE user_id = $1", user_id | |
| ) | |
| total_msg_count = await conn.fetchval("SELECT COUNT(*) FROM messages") | |
| summary = await conn.fetchval( | |
| "SELECT message_count FROM summaries WHERE user_id = $1", user_id | |
| ) | |
| total_tokens = await conn.fetchval(""" | |
| SELECT COALESCE(SUM(total_tokens), 0) FROM metrics WHERE user_id = $1 | |
| """, user_id) | |
| avg_latency = await conn.fetchval(""" | |
| SELECT COALESCE(AVG(request_duration_ms), 0) | |
| FROM metrics WHERE user_id = $1 AND success = TRUE | |
| """, user_id) | |
| return { | |
| "total_users": user_count, | |
| "user_messages": msg_count, | |
| "total_messages": total_msg_count, | |
| "summarized_messages": summary or 0, | |
| "total_tokens_used": int(total_tokens), | |
| "avg_latency_ms": round(avg_latency, 1) if avg_latency else 0, | |
| } | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Metrics | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| async def save_metric( | |
| self, | |
| user_id: Optional[int], | |
| model: str, | |
| duration_ms: float, | |
| tokens_input: int = 0, | |
| tokens_output: int = 0, | |
| success: bool = True, | |
| error_type: Optional[str] = None, | |
| ) -> None: | |
| async with self._acquire() as conn: | |
| await conn.execute(""" | |
| INSERT INTO metrics | |
| (user_id, model, request_duration_ms, tokens_input, tokens_output, | |
| total_tokens, success, error_type) | |
| VALUES ($1, $2, $3, $4, $5, $6, $7, $8) | |
| """, user_id, model, duration_ms, tokens_input, tokens_output, | |
| tokens_input + tokens_output, success, error_type) | |
| async def get_user_metrics(self, user_id: int, limit: int = 50) -> List[Dict[str, Any]]: | |
| async with self._acquire() as conn: | |
| rows = await conn.fetch(""" | |
| SELECT model, request_duration_ms, total_tokens, success, error_type, created_at | |
| FROM metrics WHERE user_id = $1 ORDER BY created_at DESC LIMIT $2 | |
| """, user_id, limit) | |
| return [dict(r) for r in rows] | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Rate Limiting | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| async def check_rate_limit(self, user_id: int) -> tuple[bool, int, float]: | |
| """Returns (allowed, remaining_requests, reset_in_seconds).""" | |
| if not config.RATE_LIMIT_ENABLED: | |
| return True, 999, 0.0 | |
| async with self._acquire() as conn: | |
| async with conn.transaction(): | |
| row = await conn.fetchrow(""" | |
| SELECT request_count, window_start | |
| FROM rate_limits WHERE user_id = $1 | |
| FOR UPDATE | |
| """, user_id) | |
| now = time.time() | |
| window_duration = 60 # 1 minute | |
| if not row: | |
| await conn.execute(""" | |
| INSERT INTO rate_limits (user_id, request_count, window_start) | |
| VALUES ($1, 1, NOW()) | |
| """, user_id) | |
| return True, config.RATE_LIMIT_REQUESTS_PER_MINUTE - 1, window_duration | |
| window_start_ts = row["window_start"].timestamp() | |
| if now - window_start_ts >= window_duration: | |
| # Window expired, reset | |
| await conn.execute(""" | |
| UPDATE rate_limits | |
| SET request_count = 1, window_start = NOW(), updated_at = NOW() | |
| WHERE user_id = $1 | |
| """, user_id) | |
| return True, config.RATE_LIMIT_REQUESTS_PER_MINUTE - 1, window_duration | |
| if row["request_count"] >= config.RATE_LIMIT_REQUESTS_PER_MINUTE: | |
| reset_in = window_duration - (now - window_start_ts) | |
| return False, 0, reset_in | |
| await conn.execute(""" | |
| UPDATE rate_limits | |
| SET request_count = request_count + 1, updated_at = NOW() | |
| WHERE user_id = $1 | |
| """, user_id) | |
| remaining = config.RATE_LIMIT_REQUESTS_PER_MINUTE - row["request_count"] - 1 | |
| reset_in = window_duration - (now - window_start_ts) | |
| return True, remaining, reset_in | |
| async def reset_rate_limit(self, user_id: int) -> None: | |
| async with self._acquire() as conn: | |
| await conn.execute(""" | |
| DELETE FROM rate_limits WHERE user_id = $1 | |
| """, user_id) | |
| db = Database() | |