MedRAG / app /db /database.py
hetsheta's picture
Migrate to PostgreSQL and ignore start-dev scripts
6782c88
Raw
History Blame Contribute Delete
4.66 kB
"""
Main application database (shared across all users).
- users table
- documents table (Qdrant vector registry)
- conversations table (user chat sessions)
- messages table (chat message history)
"""
import os
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from sqlalchemy import String, Boolean, DateTime, Text, Integer, text
from datetime import datetime, timezone
from loguru import logger
from app.core.config import get_settings
# Module-wide application settings, resolved once at import time.
settings = get_settings()

# Create the local ./data directory at import time; exist_ok makes this a
# no-op when it is already present.
os.makedirs(name="./data", exist_ok=True)
async def create_db_if_not_exists(database_url: str):
# Only try to create database for postgresql URLs
if not database_url.startswith("postgresql"):
return
from urllib.parse import urlparse
parsed = urlparse(database_url)
db_name = parsed.path.lstrip("/")
# Connection URL to default 'postgres' database
base_url = f"{parsed.scheme}://{parsed.netloc}/postgres"
# Create a temporary engine connecting to the default 'postgres' DB outside of transaction
temp_engine = create_async_engine(base_url, isolation_level="AUTOCOMMIT")
try:
async with temp_engine.connect() as conn:
# Check if the database already exists
result = await conn.execute(
text("SELECT 1 FROM pg_database WHERE datname = :dbname"),
{"dbname": db_name}
)
exists = result.scalar()
if not exists:
logger.info(f"Database '{db_name}' does not exist. Creating it...")
await conn.execute(text(f"CREATE DATABASE {db_name}"))
logger.info(f"Database '{db_name}' created successfully.")
except Exception as e:
logger.warning(f"Could not check/create database '{db_name}' automatically: {e}")
finally:
await temp_engine.dispose()
# Application-wide async engine; SQL statement echo follows the debug flag.
engine = create_async_engine(settings.database_url, echo=settings.debug)

# Factory for AsyncSession objects bound to the engine above.
# expire_on_commit=False keeps already-loaded ORM attributes readable after
# commit instead of expiring them.
AsyncSessionLocal = async_sessionmaker(
    bind=engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
class Base(DeclarativeBase):
    """Declarative base for this module's shared tables (``User``, ``DocumentRecord``)."""
class User(Base):
    """ORM model for the shared ``users`` table."""
    __tablename__ = "users"
    # 36-char string key; presumably a UUID string generated by the caller (TODO confirm).
    id: Mapped[str] = mapped_column(String(36), primary_key=True)
    # Unique and indexed, so it can serve as the lookup key for a user.
    email: Mapped[str] = mapped_column(String(255), unique=True, index=True, nullable=False)
    # Stored as given; hashing is expected to happen before this is set (not visible here).
    hashed_password: Mapped[str] = mapped_column(String(255), nullable=False)
    full_name: Mapped[str] = mapped_column(String(255), nullable=False)
    # Free-form role string; Python-side default is "user".
    role: Mapped[str] = mapped_column(String(50), default="user")
    # Python-side default: new users are active.
    is_active: Mapped[bool] = mapped_column(Boolean, default=True)
    # Timezone-aware UTC timestamp filled in by Python at insert time (no DB server default).
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
    )
class DocumentRecord(Base):
    """ORM model for the ``documents`` table (registry of uploaded documents).

    The processing outcome of each document is tracked in ``status`` and
    ``error_message``.
    """
    __tablename__ = "documents"
    # 36-char string key; presumably a UUID string generated by the caller (TODO confirm).
    id: Mapped[str] = mapped_column(String(36), primary_key=True)
    filename: Mapped[str] = mapped_column(String(512), nullable=False)
    title: Mapped[str] = mapped_column(String(512), nullable=False)
    source: Mapped[str | None] = mapped_column(String(512))
    # Presumably the number of chunks indexed for this document (TODO confirm).
    chunk_count: Mapped[int] = mapped_column(Integer, default=0)
    # Likely a users.id value, but no foreign key is declared.
    uploaded_by: Mapped[str] = mapped_column(String(36), nullable=False)
    # "" is the default when no conversation is given. server_default mirrors
    # the DEFAULT '' that init_db's ALTER TABLE adds to pre-existing tables,
    # so freshly created and migrated schemas agree.
    conversation_id: Mapped[str] = mapped_column(
        String(36), nullable=False, default="", server_default=""
    )
    # Timezone-aware UTC timestamp filled in by Python at insert time.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
    )
    # Python-side default is "processing".
    status: Mapped[str] = mapped_column(String(50), default="processing")
    error_message: Mapped[str | None] = mapped_column(Text)
async def init_db() -> None:
    """Prepare the database schema at startup.

    Steps:
    1. Best-effort creation of the target PostgreSQL database itself.
    2. Inside a single ``engine.begin()`` block, run ``create_all`` for this
       module's ``Base`` and then for ``ChatBase``.
    3. Add ``documents.conversation_id`` when an existing table lacks it.
       ``create_all`` only creates missing tables and never alters existing
       ones, so the column is patched in manually.
    """
    # Auto-create the target PostgreSQL database if it does not exist.
    await create_db_if_not_exists(settings.database_url)

    def _document_column_names(sync_conn) -> set[str]:
        # The inspector works on a synchronous connection, which run_sync provides.
        from sqlalchemy import inspect as sa_inspect

        return {column["name"] for column in sa_inspect(sync_conn).get_columns("documents")}

    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)

        # Imported here rather than at module level to avoid a circular import.
        from app.db.chat_db import ChatBase

        await conn.run_sync(ChatBase.metadata.create_all)

        existing_columns = await conn.run_sync(_document_column_names)
        if "conversation_id" not in existing_columns:
            await conn.execute(
                text("ALTER TABLE documents ADD COLUMN conversation_id VARCHAR(36) NOT NULL DEFAULT ''")
            )
async def get_db():  # type: ignore[return]
    """Yield a fresh ``AsyncSession`` from ``AsyncSessionLocal``.

    The session is opened in an ``async with`` block, so it is closed once
    the consumer resumes or closes this generator. Presumably used as a
    per-request dependency (TODO confirm against callers).
    """
    session_factory = AsyncSessionLocal
    async with session_factory() as db_session:
        yield db_session