Spaces:
Sleeping
Sleeping
File size: 1,918 Bytes
f0c339c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 | # Database connection and session management
import os
from contextlib import contextmanager
from typing import Generator, Optional

from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker

from models import Base
from config import config
class Database:
    """Database connection manager.

    Wraps one SQLAlchemy engine and a session factory (``SessionLocal``)
    bound to it, plus helpers for creating/dropping the tables declared on
    ``models.Base`` and for obtaining sessions.
    """

    def __init__(self, database_url: Optional[str] = None):
        """Initialize database connection.

        Args:
            database_url: SQLAlchemy database URL. If omitted or empty,
                ``config.get_database_url()`` is used instead.
        """
        self.database_url = database_url or config.get_database_url()
        self.engine = create_engine(
            self.database_url,
            echo=False,
            # Driver-specific options; only SQLite gets check_same_thread.
            connect_args=self._connect_args_for(self.database_url),
        )
        # Sessions neither autocommit nor autoflush; get_session() commits
        # explicitly when its with-block finishes without error.
        self.SessionLocal = sessionmaker(
            autocommit=False,
            autoflush=False,
            bind=self.engine,
        )

    @staticmethod
    def _connect_args_for(database_url: object) -> dict:
        """Return the DBAPI ``connect()`` keyword arguments for *database_url*.

        ``check_same_thread`` is an option of Python's ``sqlite3`` driver
        only; other drivers (e.g. psycopg2) reject it as an invalid
        connection option. It is therefore passed only for ``sqlite`` URLs,
        including the ``sqlite+<driver>`` form. Every other backend gets no
        extra arguments.
        """
        # "dialect[+driver]://..." -> "dialect"; str() tolerates non-str URL values.
        backend = str(database_url).split(":", 1)[0].split("+", 1)[0]
        if backend == "sqlite":
            # Allow a connection to be used from a thread other than the one
            # that created it.
            return {"check_same_thread": False}
        return {}

    def create_tables(self):
        """Create all tables declared on ``Base.metadata`` using this engine."""
        Base.metadata.create_all(bind=self.engine)

    def drop_tables(self):
        """Drop all tables declared on ``Base.metadata`` (destructive: data is lost)."""
        Base.metadata.drop_all(bind=self.engine)

    @contextmanager
    def get_session(self) -> Generator[Session, None, None]:
        """Get a database session with automatic cleanup.

        Use as ``with db.get_session() as session: ...``.

        Lifecycle:
            * If the with-block exits normally, the session is committed.
            * If the block or the commit raises, the session is rolled back
              and the exception is re-raised.
            * In every case the session is closed.
        """
        session = self.SessionLocal()
        try:
            yield session
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

    def get_session_direct(self) -> Session:
        """Return a new session from ``SessionLocal`` with no automatic cleanup.

        The caller must manage the lifecycle: commit or roll back, then close.
        """
        return self.SessionLocal()
# Process-wide Database built from the configured URL; the module-level
# helpers init_database() and get_db_session() both operate on it.
db: Database = Database()
# Import-time side effect: importing this module is enough to get the schema
# in place. NOTE(review): init_database() re-runs the same call.
db.create_tables()
def init_database():
    """Ensure the schema exists, then hand back the shared ``db`` instance.

    The tables were already created once when this module was imported.
    Calling this function runs ``create_tables()`` again on the same
    instance before returning it.
    """
    instance = db
    instance.create_tables()
    return instance
def get_db_session() -> Session:
    """Open a fresh session on the shared ``db`` (intended for Streamlit code).

    Unlike the ``Database.get_session`` context manager, nothing here is
    committed, rolled back or closed automatically. The caller owns the
    returned session's lifecycle.
    """
    owner = db
    fresh_session = owner.get_session_direct()
    return fresh_session
|