File size: 2,450 Bytes
e3804d7
 
 
 
 
 
 
 
 
 
 
b2ef5d6
e3804d7
 
 
 
 
 
 
2676b7a
 
 
e3804d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2676b7a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import os
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker, scoped_session
from .models import Base

# Load .env file to ensure DATABASE_URL is available
from dotenv import load_dotenv
load_dotenv()

# Connection-string resolution.
# A non-blank DATABASE_URL environment variable takes precedence; when it is
# missing or whitespace-only, a local SQLite file is used instead.
DATABASE_URL = os.getenv('DATABASE_URL', '').strip() or None

if not DATABASE_URL:
    # No URL configured: fall back to a SQLite file.
    # Prefer /data when it exists and is writable, otherwise place the file
    # one directory above this module.
    if os.path.exists('/data') and os.access('/data', os.W_OK):
        DB_NAME = '/data/class_data.db'
    else:
        ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
        DB_NAME = os.path.join(ROOT, 'class_data.db')

    DATABASE_URL = f"sqlite:///{DB_NAME}"
    engine = create_engine(
        DATABASE_URL,
        connect_args={"check_same_thread": False},
    )
else:
    # URL supplied through the environment.
    # Rewrite only a leading legacy "postgres://" scheme to "postgresql://".
    if DATABASE_URL.startswith('postgres://'):
        DATABASE_URL = 'postgresql://' + DATABASE_URL[len('postgres://'):]

    # pool_pre_ping tests a pooled connection before handing it out;
    # pool_recycle replaces connections older than 3600 seconds.
    engine = create_engine(
        DATABASE_URL,
        pool_pre_ping=True,
        pool_recycle=3600,
    )

# Session factory bound to the engine above, plus a scoped registry around it.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
ScopedSession = scoped_session(SessionLocal)

def init_db():
    """Create all model tables and apply the SQLite-only column migration.

    Calls ``Base.metadata.create_all`` on the module-level engine. When that
    engine uses the SQLite dialect, the ``schedule`` table is then inspected
    and the ``instructor`` and ``note`` TEXT columns are added if they are
    not already present. Other dialects get no migration step.
    """
    Base.metadata.create_all(bind=engine)

    # Ask the engine for its dialect instead of substring-matching the URL:
    # a non-SQLite URL whose host or database name happens to contain
    # "sqlite" must not trigger SQLite-specific PRAGMA statements.
    if engine.dialect.name != 'sqlite':
        return

    # Statements are kept as literals; the column names are fixed constants,
    # never user input.
    migrations = (
        ('instructor', "ALTER TABLE schedule ADD COLUMN instructor TEXT"),
        ('note', "ALTER TABLE schedule ADD COLUMN note TEXT"),
    )

    # engine.begin() commits when the block succeeds and rolls back if any
    # statement raises, so no explicit commit call is needed.
    with engine.begin() as conn:
        table_info = conn.execute(text("PRAGMA table_info(schedule)"))
        existing = {row.name for row in table_info}

        if not existing:
            # PRAGMA table_info yields no rows for a table that does not
            # exist; there is nothing to alter in that case.
            return

        for column, ddl in migrations:
            if column not in existing:
                conn.execute(text(ddl))

def get_db():
    """Create and return a fresh session from ``SessionLocal``.

    The session is not closed here; the caller must close it when done.
    """
    session = SessionLocal()
    return session