| """ |
| Main application database (shared across all users). |
| - users table |
| - documents table (Qdrant vector registry) |
| - conversations table (user chat sessions) |
| - messages table (chat message history) |
| """ |
| import os |
| from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker |
| from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column |
| from sqlalchemy import String, Boolean, DateTime, Text, Integer, text |
| from datetime import datetime, timezone |
| from loguru import logger |
|
|
| from app.core.config import get_settings |
|
|
# Application settings, resolved once when this module is imported.
settings = get_settings()

# Make sure a local ./data directory exists as an import-time side effect
# (exist_ok makes this a no-op when it is already present). Presumably used
# for on-disk storage such as a SQLite file — TODO confirm.
os.makedirs(name="./data", exist_ok=True)
|
|
|
|
| async def create_db_if_not_exists(database_url: str): |
| |
| if not database_url.startswith("postgresql"): |
| return |
| |
| from urllib.parse import urlparse |
| parsed = urlparse(database_url) |
| db_name = parsed.path.lstrip("/") |
| |
| |
| base_url = f"{parsed.scheme}://{parsed.netloc}/postgres" |
| |
| |
| temp_engine = create_async_engine(base_url, isolation_level="AUTOCOMMIT") |
| try: |
| async with temp_engine.connect() as conn: |
| |
| result = await conn.execute( |
| text("SELECT 1 FROM pg_database WHERE datname = :dbname"), |
| {"dbname": db_name} |
| ) |
| exists = result.scalar() |
| if not exists: |
| logger.info(f"Database '{db_name}' does not exist. Creating it...") |
| await conn.execute(text(f"CREATE DATABASE {db_name}")) |
| logger.info(f"Database '{db_name}' created successfully.") |
| except Exception as e: |
| logger.warning(f"Could not check/create database '{db_name}' automatically: {e}") |
| finally: |
| await temp_engine.dispose() |
|
|
|
|
# Process-wide async engine for the application database; SQL statement
# echoing follows the debug setting.
engine = create_async_engine(
    settings.database_url,
    echo=settings.debug,
)

# Session factory bound to the engine. AsyncSession is the factory's default
# class (spelled out here for clarity). ORM objects are not expired on
# commit, so their attributes stay readable after commit without a reload.
AsyncSessionLocal = async_sessionmaker(
    bind=engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
|
|
|
|
class Base(DeclarativeBase):
    """Declarative base for the ORM models defined in this module."""
|
|
|
|
class User(Base):
    """Application user account, stored in the ``users`` table."""

    __tablename__ = "users"

    # 36-character string primary key with no default here, so callers must
    # supply it (length suggests a UUID string — TODO confirm).
    id: Mapped[str] = mapped_column(String(36), primary_key=True)
    # Unique and indexed.
    email: Mapped[str] = mapped_column(
        String(255), nullable=False, unique=True, index=True
    )
    hashed_password: Mapped[str] = mapped_column(String(255), nullable=False)
    full_name: Mapped[str] = mapped_column(String(255), nullable=False)
    # Python-side defaults, applied on INSERT when no value is given.
    role: Mapped[str] = mapped_column(String(50), default="user")
    is_active: Mapped[bool] = mapped_column(Boolean, default=True)
    # Timezone-aware UTC timestamp, set on the Python side at insert time.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        default=lambda: datetime.now(tz=timezone.utc),
    )
|
|
|
|
class DocumentRecord(Base):
    """Uploaded-document registry row, stored in the ``documents`` table."""

    __tablename__ = "documents"

    # 36-character string primary key supplied by the caller.
    id: Mapped[str] = mapped_column(String(36), primary_key=True)
    filename: Mapped[str] = mapped_column(String(512), nullable=False)
    title: Mapped[str] = mapped_column(String(512), nullable=False)
    # Optional (nullable) source description.
    source: Mapped[str | None] = mapped_column(String(512))
    chunk_count: Mapped[int] = mapped_column(Integer, default=0)
    # Id of the uploading user; stored as a plain string, with no foreign key
    # declared here.
    uploaded_by: Mapped[str] = mapped_column(String(36), nullable=False)
    # Empty string (not NULL) when no conversation is attached.
    conversation_id: Mapped[str] = mapped_column(
        String(36), nullable=False, default=""
    )
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        default=lambda: datetime.now(tz=timezone.utc),
    )
    # Free-form status string; new rows start as "processing".
    status: Mapped[str] = mapped_column(String(50), default="processing")
    # Optional error text (nullable).
    error_message: Mapped[str | None] = mapped_column(Text)
|
|
|
|
async def init_db() -> None:
    """Prepare the database schema at startup.

    Steps:
    1. Ensure the target database exists (PostgreSQL only, best effort).
    2. In one ``engine.begin()`` transaction, create any missing tables
       declared on ``Base`` and on ``ChatBase``.
    3. Add the ``documents.conversation_id`` column to a pre-existing
       ``documents`` table that lacks it.
    """
    await create_db_if_not_exists(settings.database_url)

    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)

        # Imported here rather than at module top — presumably to avoid a
        # circular import; verify against app.db.chat_db.
        from app.db.chat_db import ChatBase

        await conn.run_sync(ChatBase.metadata.create_all)

        def _document_column_names(sync_conn):
            # Runs on the sync side of run_sync; inspector needs a sync connection.
            from sqlalchemy import inspect as sa_inspect

            return {
                column["name"]
                for column in sa_inspect(sync_conn).get_columns("documents")
            }

        existing_columns = await conn.run_sync(_document_column_names)
        if "conversation_id" in existing_columns:
            return

        # Lightweight in-place migration for tables created before the column existed.
        await conn.execute(text("ALTER TABLE documents ADD COLUMN conversation_id VARCHAR(36) NOT NULL DEFAULT ''"))
|
|
|
|
async def get_db():
    """Async generator that yields a single session from ``AsyncSessionLocal``.

    The session is closed when the ``async with`` exits, i.e. once the
    consumer finishes with the generator (presumably a per-request
    dependency — TODO confirm).
    """
    async with AsyncSessionLocal() as db_session:
        yield db_session
|
|