File size: 4,461 Bytes
59f2028
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
"""
Database connection and initialization module

Handles:
- Database connection management
- Table creation and initialization
- Connection string configuration
- Database diagnostics
"""

import os
import pyodbc
from threading import Lock
from .models import get_table_definitions

# Connection settings; each one can be overridden through the environment.
DB_SERVER = os.getenv("DB_SERVER", r"(localdb)\MSSQLLocalDB")
DB_DATABASE = os.getenv("DB_DATABASE", "AuthenticationDB1")
DB_DRIVER = os.getenv("DB_DRIVER", "ODBC Driver 17 for SQL Server")

# The server is treated as local when its name starts with "localhost",
# "." or "(localdb)" (case-insensitive), or contains a backslash the way
# the default "(localdb)\MSSQLLocalDB" does.
is_local = (
    DB_SERVER.lower().startswith(("localhost", ".", "(localdb)"))
    or "\\" in DB_SERVER
)

# Both flavours of the connection string share the same leading part.
CONN_STR = f"DRIVER={{{DB_DRIVER}}};SERVER={DB_SERVER};DATABASE={DB_DATABASE};"
if is_local:
    # Windows local / LocalDB: integrated (trusted) authentication.
    CONN_STR += "Trusted_Connection=yes;TrustServerCertificate=yes;"
else:
    # Remote server: SQL authentication with DB_USER / DB_PASSWORD read from
    # the environment at import time.
    CONN_STR += (
        f"UID={os.getenv('DB_USER')};PWD={os.getenv('DB_PASSWORD')};"
        "Encrypt=yes;TrustServerCertificate=yes;"
    )

# One-time initialization state: the flag records that initialization has
# already been attempted/completed, the lock serializes that attempt.
_db_init_done = False
_db_init_lock = Lock()


def get_db_connection():
    """
    Open a new pyodbc connection using the module-level CONN_STR

    The connection is requested with timeout=5 so an unreachable server
    fails quickly.

    Raises:
        RuntimeError: If CONN_STR is not a trusted-connection string and
            DB_USER or DB_PASSWORD is missing/empty in the environment
        pyodbc.Error: If connection fails
    """
    # Only SQL-auth connection strings need credentials from the environment.
    needs_credentials = "Trusted_Connection=yes" not in CONN_STR
    if needs_credentials and not (os.getenv("DB_USER") and os.getenv("DB_PASSWORD")):
        raise RuntimeError("DB_USER/DB_PASSWORD are not set in the environment.")

    return pyodbc.connect(CONN_STR, timeout=5)


def init_db():
    """
    Create database tables if they do not exist

    Executes every SQL statement returned by get_table_definitions() on a
    single connection and commits once all of them have run. The
    statements are expected to be safe to re-run (TODO confirm against
    the definitions in the models module).

    Creates:
    - Users table for authentication
    - BlacklistedTokens table for token management
    - RefreshTokens table for refresh token storage

    Raises:
        RuntimeError: Propagated from get_db_connection() when credentials
            are missing
        pyodbc.Error: If connecting or executing a statement fails; in that
            case nothing is committed by this function
    """
    conn = get_db_connection()
    try:
        cur = conn.cursor()

        # Only the SQL text is needed; the table names (dict keys) are unused.
        for create_sql in get_table_definitions().values():
            cur.execute(create_sql)

        conn.commit()
    finally:
        # Always release the connection, even when a statement fails.
        conn.close()


def ensure_database_initialized():
    """
    Ensure database is initialized (thread-safe)

    Call this from Flask app startup to initialize database once.
    Controlled by RUN_INIT_DB environment variable: initialization is only
    attempted when it equals "1".

    The module-level flag is set only after init_db() succeeds, so a failed
    attempt is not reported as "initialized" and a later call tries again.

    Returns:
        bool: True once initialization has completed in this process,
        False when it has not run (including when RUN_INIT_DB is not "1")

    Raises:
        Exception: Whatever init_db() raised, re-raised after being reported
    """
    global _db_init_done

    if os.getenv("RUN_INIT_DB", "0") != "1":
        return _db_init_done

    if not _db_init_done:
        with _db_init_lock:
            # Re-check under the lock: another thread may have finished
            # initialization while this one was waiting.
            if not _db_init_done:
                try:
                    init_db()
                except Exception as e:
                    print(f"? Database initialization failed: {e}")
                    raise
                _db_init_done = True
                print("? Database initialized successfully")

    return _db_init_done


def get_database_info():
    """
    Collect database diagnostics (intended for admin use only)

    The report contains the installed ODBC drivers (or the error raised
    while listing them), the configured database name, whether the server
    is classified as LocalDB or Remote, and the outcome of a trial
    connection. Connection failures are reported by exception type name
    only; neither the connection string nor the error message is included.

    Returns:
        dict: Diagnostic keys and values as described above
    """
    # Driver discovery: record either the list or the reason it failed.
    try:
        report = {"drivers_found": pyodbc.drivers()}
    except Exception as exc:
        report = {"drivers_found_error": str(exc)}

    # Static, non-sensitive configuration.
    report["database_name"] = DB_DATABASE
    report["server_type"] = "LocalDB" if is_local else "Remote"

    # Trial connection: open and immediately close.
    try:
        get_db_connection().close()
    except Exception as exc:
        report.update(connection_status="error", error_type=type(exc).__name__)
    else:
        report["connection_status"] = "ok"

    return report


def test_database_connection():
    """
    Test database connection and return status

    Opens a connection, runs "SELECT 1" and checks that the value 1 comes
    back. Any exception is caught and turned into a failure result rather
    than raised to the caller.

    Returns:
        tuple: (success: bool, message: str)
    """
    try:
        conn = get_db_connection()
        try:
            # Test basic query
            cur = conn.cursor()
            cur.execute("SELECT 1")
            result = cur.fetchone()
        finally:
            # Close the connection even if the probe query raised.
            conn.close()

        if result and result[0] == 1:
            return True, "Database connection successful"
        return False, "Database query failed"

    except Exception as e:
        return False, f"Database connection failed: {str(e)}"