SkyAlone / database.py
FreshPixels's picture
Upload 9 files
7e686b5 verified
Raw
History Blame Contribute Delete
18.8 kB
import json
import logging
import time
from typing import Optional, List, Dict, Any

import asyncpg

from config import config
# Module-wide logger, named after this module so records show where they came from.
logger = logging.getLogger(name=__name__)
class Database:
    """Async PostgreSQL storage for users, chat messages, summaries, metrics
    and rate limits, backed by an asyncpg connection pool.

    ``connect()`` must be awaited before any other method; ``disconnect()``
    closes the pool again.
    """

    # Length of the fixed rate-limiting window, in seconds.
    RATE_LIMIT_WINDOW_SECONDS = 60

    def __init__(self) -> None:
        # Set by connect(), reset to None by disconnect().
        self.pool: Optional[asyncpg.Pool] = None

    async def connect(self) -> None:
        """Create the connection pool and make sure the schema exists.

        Calling this while already connected is a no-op, so a repeated call
        cannot leak the existing pool. If table/index creation fails, the new
        pool is closed before the error propagates.
        """
        if self.pool is not None:
            logger.warning("Database.connect() called while already connected; ignoring")
            return
        self.pool = await asyncpg.create_pool(
            dsn=config.DATABASE_URL,
            min_size=1,
            max_size=5,  # raised for scalability
            command_timeout=60,
            server_settings={
                "jit": "off",  # JIT disabled for stability
                "application_name": "glm_bot",
            },
        )
        logger.info("Database pool created (max_size=5)")
        try:
            await self._create_tables()
            await self._create_indexes()
        except Exception:
            # Do not leave a half-initialised pool behind.
            await self.disconnect()
            raise

    async def disconnect(self) -> None:
        """Close the pool if one is open; safe to call more than once.

        ``self.pool`` is cleared so later calls hit the explicit
        "not connected" error in ``_acquire`` instead of a closed pool.
        """
        pool, self.pool = self.pool, None
        if pool is not None:
            await pool.close()
            logger.info("Database pool closed")

    def _acquire(self):
        """Return the pool's acquire context manager.

        Raises:
            RuntimeError: if ``connect()`` has not been called.
        """
        if self.pool is None:
            raise RuntimeError("Database not connected. Call connect() first.")
        return self.pool.acquire()

    async def _create_tables(self) -> None:
        """Create all tables if they do not already exist."""
        async with self._acquire() as conn:
            # Users table
            await conn.execute("""
                CREATE TABLE IF NOT EXISTS users (
                    id BIGINT PRIMARY KEY,
                    username VARCHAR(255),
                    first_name VARCHAR(255),
                    last_name VARCHAR(255),
                    created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
                    updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
                    settings JSONB DEFAULT '{}'::jsonb
                )
            """)
            # Messages table
            await conn.execute("""
                CREATE TABLE IF NOT EXISTS messages (
                    id SERIAL PRIMARY KEY,
                    user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
                    role VARCHAR(20) NOT NULL CHECK (role IN ('user', 'assistant', 'system')),
                    content TEXT NOT NULL,
                    tokens_used INTEGER DEFAULT 0,
                    is_summarized BOOLEAN DEFAULT FALSE,
                    created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
                )
            """)
            # Summaries table (at most one row per user, enforced by UNIQUE)
            await conn.execute("""
                CREATE TABLE IF NOT EXISTS summaries (
                    id SERIAL PRIMARY KEY,
                    user_id BIGINT NOT NULL UNIQUE REFERENCES users(id) ON DELETE CASCADE,
                    summary TEXT NOT NULL,
                    message_count INTEGER NOT NULL DEFAULT 0,
                    created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
                    updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
                )
            """)
            # Metrics table, used for monitoring
            await conn.execute("""
                CREATE TABLE IF NOT EXISTS metrics (
                    id SERIAL PRIMARY KEY,
                    user_id BIGINT REFERENCES users(id) ON DELETE SET NULL,
                    model VARCHAR(100),
                    request_duration_ms FLOAT,
                    tokens_input INTEGER DEFAULT 0,
                    tokens_output INTEGER DEFAULT 0,
                    total_tokens INTEGER DEFAULT 0,
                    success BOOLEAN DEFAULT TRUE,
                    error_type VARCHAR(100),
                    created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
                )
            """)
            # Rate limiting table (one row per user)
            await conn.execute("""
                CREATE TABLE IF NOT EXISTS rate_limits (
                    user_id BIGINT PRIMARY KEY REFERENCES users(id) ON DELETE CASCADE,
                    request_count INTEGER DEFAULT 0,
                    window_start TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
                    updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
                )
            """)
            logger.info("Database tables created/verified")

    async def _create_indexes(self) -> None:
        """Create the secondary indexes used by the message and metrics queries."""
        async with self._acquire() as conn:
            await conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_messages_user_id_created_at
                ON messages(user_id, created_at DESC)
            """)
            await conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_messages_user_id_summarized
                ON messages(user_id, is_summarized, created_at DESC)
            """)
            await conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_metrics_user_created
                ON metrics(user_id, created_at DESC)
            """)
            await conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_metrics_created_at
                ON metrics(created_at DESC)
            """)

    # ═══════════════════════════════════════════════════════════════
    # Row / value helpers
    # ═══════════════════════════════════════════════════════════════
    @staticmethod
    def _row_to_message(row: Any) -> Dict[str, Any]:
        """Convert a ``messages`` row into the dict shape returned to callers."""
        return {
            "id": row["id"],
            "role": row["role"],
            "content": row["content"],
            "tokens_used": row["tokens_used"],
            "created_at": row["created_at"],
        }

    @classmethod
    def _select_within_budget(cls, rows: List[Any], max_tokens: int) -> List[Dict[str, Any]]:
        """Pick the newest messages whose combined token cost fits ``max_tokens``.

        ``rows`` must be ordered newest first. A message with no recorded
        token count (0 or NULL) is estimated at two tokens per
        whitespace-separated word. The newest message is always included,
        even if it alone exceeds the budget. The result is oldest-first.
        """
        selected: List[Dict[str, Any]] = []
        total_tokens = 0
        for row in rows:
            msg_tokens = row["tokens_used"] or len(row["content"].split()) * 2
            if selected and total_tokens + msg_tokens > max_tokens:
                break
            total_tokens += msg_tokens
            selected.append(cls._row_to_message(row))
        # Collected newest-first; flip once instead of inserting at index 0.
        selected.reverse()
        return selected

    @staticmethod
    def _decode_settings(raw: Any) -> Dict[str, Any]:
        """Turn a raw ``users.settings`` value from asyncpg into a dict.

        Without a registered JSON codec, asyncpg returns ``jsonb`` values as
        ``str``, so JSON text is parsed here. An already-decoded dict is
        passed through. Empty, null or non-object values become ``{}``.
        """
        if not raw:
            return {}
        if isinstance(raw, str):
            raw = json.loads(raw)
        return raw if isinstance(raw, dict) else {}

    @staticmethod
    def _encode_settings(settings: Any) -> str:
        """Serialise settings to JSON text for a ``jsonb`` query parameter.

        A ``str`` is assumed to already be JSON and is passed through unchanged.
        """
        if isinstance(settings, str):
            return settings
        return json.dumps(settings)

    @staticmethod
    def _parse_affected_rows(status: str) -> int:
        """Extract the row count from a command tag such as ``"DELETE 3"`` (0 if absent)."""
        try:
            return int(status.split()[-1])
        except (ValueError, IndexError):
            return 0

    # ═══════════════════════════════════════════════════════════════
    # Users
    # ═══════════════════════════════════════════════════════════════
    async def upsert_user(
        self,
        user_id: int,
        username: Optional[str],
        first_name: Optional[str],
        last_name: Optional[str],
    ) -> None:
        """Insert the user, or refresh their name fields and ``updated_at`` if they exist."""
        async with self._acquire() as conn:
            await conn.execute("""
                INSERT INTO users (id, username, first_name, last_name)
                VALUES ($1, $2, $3, $4)
                ON CONFLICT (id) DO UPDATE SET
                    username = EXCLUDED.username,
                    first_name = EXCLUDED.first_name,
                    last_name = EXCLUDED.last_name,
                    updated_at = NOW()
            """, user_id, username, first_name, last_name)

    async def get_user_settings(self, user_id: int) -> Dict[str, Any]:
        """Return the user's settings as a dict.

        Returns ``{}`` if the user is unknown or has no settings.
        """
        async with self._acquire() as conn:
            raw = await conn.fetchval("""
                SELECT settings FROM users WHERE id = $1
            """, user_id)
        return self._decode_settings(raw)

    async def update_user_settings(self, user_id: int, settings: Dict[str, Any]) -> None:
        """Replace the user's settings; does nothing if the user row does not exist.

        The value is serialised to JSON text because asyncpg expects ``str``
        for ``jsonb`` parameters when no codec is registered.
        """
        async with self._acquire() as conn:
            await conn.execute("""
                UPDATE users SET settings = $2::jsonb, updated_at = NOW() WHERE id = $1
            """, user_id, self._encode_settings(settings))

    # ═══════════════════════════════════════════════════════════════
    # Messages
    # ═══════════════════════════════════════════════════════════════
    async def save_message(
        self, user_id: int, role: str, content: str, tokens_used: int = 0
    ) -> None:
        """Append one message to the user's history.

        ``role`` must be 'user', 'assistant' or 'system' (enforced by a table CHECK).
        """
        async with self._acquire() as conn:
            await conn.execute("""
                INSERT INTO messages (user_id, role, content, tokens_used)
                VALUES ($1, $2, $3, $4)
            """, user_id, role, content, tokens_used)

    async def get_messages(self, user_id: int, limit: int = 30) -> List[Dict[str, Any]]:
        """Return up to ``limit`` most recent unsummarized messages, oldest first."""
        async with self._acquire() as conn:
            rows = await conn.fetch("""
                SELECT id, role, content, tokens_used, created_at
                FROM messages
                WHERE user_id = $1 AND is_summarized = FALSE
                ORDER BY created_at DESC, id DESC
                LIMIT $2
            """, user_id, limit)
        # Fetched newest-first so LIMIT keeps the latest; reverse to chronological order.
        return [self._row_to_message(r) for r in reversed(rows)]

    async def get_messages_with_token_budget(
        self, user_id: int, max_tokens: int
    ) -> List[Dict[str, Any]]:
        """Return the most recent unsummarized messages that fit within the token budget.

        The newest message is always included. The result is oldest first.
        """
        async with self._acquire() as conn:
            rows = await conn.fetch("""
                SELECT id, role, content, tokens_used, created_at
                FROM messages
                WHERE user_id = $1 AND is_summarized = FALSE
                ORDER BY created_at DESC, id DESC
            """, user_id)
        return self._select_within_budget(rows, max_tokens)

    # ═══════════════════════════════════════════════════════════════
    # Summaries
    # ═══════════════════════════════════════════════════════════════
    async def get_summary(self, user_id: int) -> Optional[str]:
        """Return the user's stored summary text, or None if there is none."""
        async with self._acquire() as conn:
            return await conn.fetchval("""
                SELECT summary FROM summaries WHERE user_id = $1
            """, user_id)

    async def save_summary(self, user_id: int, summary: str, message_count: int) -> None:
        """Store the user's summary, replacing any previous text.

        ``message_count`` is added to the previously stored count, so the
        column accumulates the total number of summarized messages.
        """
        async with self._acquire() as conn:
            await conn.execute("""
                INSERT INTO summaries (user_id, summary, message_count, updated_at)
                VALUES ($1, $2, $3, NOW())
                ON CONFLICT (user_id) DO UPDATE SET
                    summary = EXCLUDED.summary,
                    message_count = summaries.message_count + EXCLUDED.message_count,
                    updated_at = NOW()
            """, user_id, summary, message_count)

    async def mark_summarized(self, user_id: int, cutoff_id: int) -> None:
        """Flag every message of the user with ``id <= cutoff_id`` as summarized."""
        async with self._acquire() as conn:
            await conn.execute("""
                UPDATE messages
                SET is_summarized = TRUE
                WHERE user_id = $1 AND id <= $2
            """, user_id, cutoff_id)

    async def get_oldest_unsummarized(self, user_id: int, limit: int) -> List[Dict[str, Any]]:
        """Return up to ``limit`` oldest unsummarized messages (id, role, content), oldest first."""
        async with self._acquire() as conn:
            rows = await conn.fetch("""
                SELECT id, role, content
                FROM messages
                WHERE user_id = $1 AND is_summarized = FALSE
                ORDER BY created_at ASC, id ASC
                LIMIT $2
            """, user_id, limit)
        return [{"id": r["id"], "role": r["role"], "content": r["content"]} for r in rows]

    async def count_unsummarized(self, user_id: int) -> int:
        """Return how many of the user's messages are not yet summarized."""
        async with self._acquire() as conn:
            val = await conn.fetchval("""
                SELECT COUNT(*) FROM messages
                WHERE user_id = $1 AND is_summarized = FALSE
            """, user_id)
        return val or 0

    # ═══════════════════════════════════════════════════════════════
    # Clear History
    # ═══════════════════════════════════════════════════════════════
    async def clear_history(self, user_id: int) -> int:
        """Delete all of the user's messages and their summary in one transaction.

        Returns:
            The number of deleted messages.
        """
        async with self._acquire() as conn:
            async with conn.transaction():
                result = await conn.execute("""
                    DELETE FROM messages WHERE user_id = $1
                """, user_id)
                await conn.execute("""
                    DELETE FROM summaries WHERE user_id = $1
                """, user_id)
        count = self._parse_affected_rows(result)
        logger.info("Cleared %d messages and summary for user %s", count, user_id)
        return count

    # ═══════════════════════════════════════════════════════════════
    # Stats
    # ═══════════════════════════════════════════════════════════════
    async def get_stats(self, user_id: int) -> Dict[str, Any]:
        """Return global counters plus the user's message, token and latency stats."""
        async with self._acquire() as conn:
            user_count = await conn.fetchval("SELECT COUNT(*) FROM users")
            msg_count = await conn.fetchval(
                "SELECT COUNT(*) FROM messages WHERE user_id = $1", user_id
            )
            total_msg_count = await conn.fetchval("SELECT COUNT(*) FROM messages")
            summary = await conn.fetchval(
                "SELECT message_count FROM summaries WHERE user_id = $1", user_id
            )
            total_tokens = await conn.fetchval("""
                SELECT COALESCE(SUM(total_tokens), 0) FROM metrics WHERE user_id = $1
            """, user_id)
            avg_latency = await conn.fetchval("""
                SELECT COALESCE(AVG(request_duration_ms), 0)
                FROM metrics WHERE user_id = $1 AND success = TRUE
            """, user_id)
        return {
            "total_users": user_count,
            "user_messages": msg_count,
            "total_messages": total_msg_count,
            "summarized_messages": summary or 0,
            "total_tokens_used": int(total_tokens),
            "avg_latency_ms": round(avg_latency, 1) if avg_latency else 0,
        }

    # ═══════════════════════════════════════════════════════════════
    # Metrics
    # ═══════════════════════════════════════════════════════════════
    async def save_metric(
        self,
        user_id: Optional[int],
        model: str,
        duration_ms: float,
        tokens_input: int = 0,
        tokens_output: int = 0,
        success: bool = True,
        error_type: Optional[str] = None,
    ) -> None:
        """Record one model request; ``total_tokens`` is stored as input + output."""
        async with self._acquire() as conn:
            await conn.execute("""
                INSERT INTO metrics
                (user_id, model, request_duration_ms, tokens_input, tokens_output,
                 total_tokens, success, error_type)
                VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
            """, user_id, model, duration_ms, tokens_input, tokens_output,
                tokens_input + tokens_output, success, error_type)

    async def get_user_metrics(self, user_id: int, limit: int = 50) -> List[Dict[str, Any]]:
        """Return up to ``limit`` most recent metric rows for the user, newest first."""
        async with self._acquire() as conn:
            rows = await conn.fetch("""
                SELECT model, request_duration_ms, total_tokens, success, error_type, created_at
                FROM metrics WHERE user_id = $1 ORDER BY created_at DESC LIMIT $2
            """, user_id, limit)
        return [dict(r) for r in rows]

    # ═══════════════════════════════════════════════════════════════
    # Rate Limiting
    # ═══════════════════════════════════════════════════════════════
    async def check_rate_limit(self, user_id: int) -> tuple[bool, int, float]:
        """Count one request against the user's fixed-window rate limit.

        A window lasts ``RATE_LIMIT_WINDOW_SECONDS``. It starts at the first
        request after the previous window expired. At most
        ``config.RATE_LIMIT_REQUESTS_PER_MINUTE`` requests are allowed per window.

        Returns:
            ``(allowed, remaining_requests, reset_in_seconds)``. When rate
            limiting is disabled in config, this is ``(True, 999, 0.0)``.
        """
        if not config.RATE_LIMIT_ENABLED:
            return True, 999, 0.0
        limit = config.RATE_LIMIT_REQUESTS_PER_MINUTE
        window = self.RATE_LIMIT_WINDOW_SECONDS
        async with self._acquire() as conn:
            async with conn.transaction():
                # Ensure the row exists before locking it. With ON CONFLICT DO
                # NOTHING, concurrent first requests serialise on the key instead
                # of one failing with a unique violation.
                await conn.execute("""
                    INSERT INTO rate_limits (user_id, request_count, window_start)
                    VALUES ($1, 0, NOW())
                    ON CONFLICT (user_id) DO NOTHING
                """, user_id)
                # Elapsed time is computed entirely on the database clock, so
                # skew between the app host and the DB cannot distort the window.
                row = await conn.fetchrow("""
                    SELECT request_count, NOW() - window_start AS elapsed
                    FROM rate_limits WHERE user_id = $1
                    FOR UPDATE
                """, user_id)
                elapsed = max(0.0, row["elapsed"].total_seconds())
                if elapsed >= window:
                    # Window expired: start a fresh one containing this request.
                    await conn.execute("""
                        UPDATE rate_limits
                        SET request_count = 1, window_start = NOW(), updated_at = NOW()
                        WHERE user_id = $1
                    """, user_id)
                    return True, limit - 1, float(window)
                reset_in = window - elapsed
                if row["request_count"] >= limit:
                    return False, 0, reset_in
                await conn.execute("""
                    UPDATE rate_limits
                    SET request_count = request_count + 1, updated_at = NOW()
                    WHERE user_id = $1
                """, user_id)
                return True, limit - row["request_count"] - 1, reset_in

    async def reset_rate_limit(self, user_id: int) -> None:
        """Forget the user's rate-limit state; their next request opens a new window."""
        async with self._acquire() as conn:
            await conn.execute("""
                DELETE FROM rate_limits WHERE user_id = $1
            """, user_id)
# Shared module-level instance; importers must await db.connect() before use.
db: Database = Database()