Piyush1225 commited on
Commit
c92a083
·
1 Parent(s): 2696142

fix: auto-fix postgresql:// to postgresql+psycopg2:// for SQLAlchemy 2.x

Browse files
Files changed (1) hide show
  1. adaptiveauth/core/database.py +12 -2
adaptiveauth/core/database.py CHANGED
@@ -15,6 +15,14 @@ _engine = None
15
  _SessionLocal = None
16
 
17
 
 
 
 
 
 
 
 
 
18
  def get_engine(database_url: Optional[str] = None, echo: bool = False):
19
  """Get or create database engine."""
20
  global _engine
@@ -22,6 +30,7 @@ def get_engine(database_url: Optional[str] = None, echo: bool = False):
22
  if _engine is None:
23
  settings = get_settings()
24
  url = database_url or settings.DATABASE_URL
 
25
  echo = echo or settings.DATABASE_ECHO
26
 
27
  # Configure engine based on database type
@@ -116,11 +125,12 @@ class DatabaseManager:
116
  """Get database engine."""
117
  if self._engine is None:
118
  connect_args = {}
119
- if self.database_url.startswith("sqlite"):
 
120
  connect_args["check_same_thread"] = False
121
 
122
  self._engine = create_engine(
123
- self.database_url,
124
  connect_args=connect_args,
125
  echo=self.echo,
126
  pool_pre_ping=True,
 
15
  _SessionLocal = None
16
 
17
 
18
+ def _fix_db_url(url: str) -> str:
19
+ """SQLAlchemy 2.x requires postgresql+psycopg2:// not postgresql://"""
20
+ if url.startswith("postgres://") or url.startswith("postgresql://"):
21
+ url = url.replace("postgres://", "postgresql+psycopg2://", 1)
22
+ url = url.replace("postgresql://", "postgresql+psycopg2://", 1)
23
+ return url
24
+
25
+
26
  def get_engine(database_url: Optional[str] = None, echo: bool = False):
27
  """Get or create database engine."""
28
  global _engine
 
30
  if _engine is None:
31
  settings = get_settings()
32
  url = database_url or settings.DATABASE_URL
33
+ url = _fix_db_url(url)
34
  echo = echo or settings.DATABASE_ECHO
35
 
36
  # Configure engine based on database type
 
125
  """Get database engine."""
126
  if self._engine is None:
127
  connect_args = {}
128
+ url = _fix_db_url(self.database_url)
129
+ if url.startswith("sqlite"):
130
  connect_args["check_same_thread"] = False
131
 
132
  self._engine = create_engine(
133
+ url,
134
  connect_args=connect_args,
135
  echo=self.echo,
136
  pool_pre_ping=True,