"""Persistence for per-user OAuth sessions, with tokens Fernet-encrypted at rest."""

import time
import logging
import uuid
from typing import Optional, List

import base64
import hashlib
import json

from cryptography.fernet import Fernet

from sqlalchemy import select, delete, update
from sqlalchemy.ext.asyncio import AsyncSession
from open_webui.internal.db import Base, get_async_db_context
from open_webui.env import OAUTH_SESSION_TOKEN_ENCRYPTION_KEY

from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, Index

log = logging.getLogger(__name__)

# Lifetime (seconds) assumed for a token whose dict carries no truthy 'expires_at'.
_DEFAULT_TOKEN_TTL_SECONDS = 3600

####################
# DB MODEL
####################


class OAuthSession(Base):
    __tablename__ = 'oauth_session'

    id = Column(Text, primary_key=True, unique=True)
    user_id = Column(Text, nullable=False)
    provider = Column(Text, nullable=False)
    # Fernet-encrypted JSON of the token dict (e.g. access_token, id_token, refresh_token).
    token = Column(Text, nullable=False)
    expires_at = Column(BigInteger, nullable=False)
    created_at = Column(BigInteger, nullable=False)
    updated_at = Column(BigInteger, nullable=False)

    # Add indexes for better performance
    __table_args__ = (
        Index('idx_oauth_session_user_id', 'user_id'),
        Index('idx_oauth_session_expires_at', 'expires_at'),
        Index('idx_oauth_session_user_provider', 'user_id', 'provider'),
    )


class OAuthSessionModel(BaseModel):
    id: str
    user_id: str
    provider: str
    token: dict  # decrypted token dict
    expires_at: int  # timestamp in epoch
    created_at: int  # timestamp in epoch
    updated_at: int  # timestamp in epoch

    model_config = ConfigDict(from_attributes=True)


####################
# Forms
####################


class OAuthSessionResponse(BaseModel):
    id: str
    user_id: str
    provider: str
    expires_at: int


class OAuthSessionTable:
    """Async data-access layer for ``oauth_session`` rows.

    Tokens are serialized to JSON and Fernet-encrypted before being written,
    and decrypted when read back. Query methods swallow exceptions, log them,
    and return ``None`` / ``[]`` / ``False`` instead of raising.
    """

    def __init__(self):
        """Build the Fernet cipher from OAUTH_SESSION_TOKEN_ENCRYPTION_KEY.

        Raises:
            Exception: if the key is unset, or if Fernet rejects the derived key.
        """
        self.encryption_key = OAUTH_SESSION_TOKEN_ENCRYPTION_KEY
        if not self.encryption_key:
            raise Exception('OAUTH_SESSION_TOKEN_ENCRYPTION_KEY is not set')

        # Fernet wants 32 url-safe base64-encoded bytes (a 44-character string).
        # A key of any other length is stretched to 32 bytes with SHA-256; a
        # 44-character key is used verbatim and must already be a valid Fernet key.
        if len(self.encryption_key) != 44:
            key_bytes = hashlib.sha256(self.encryption_key.encode()).digest()
            self.encryption_key = base64.urlsafe_b64encode(key_bytes)
        else:
            self.encryption_key = self.encryption_key.encode()

        try:
            self.fernet = Fernet(self.encryption_key)
        except Exception as e:
            log.error(f'Error initializing Fernet with provided key: {e}')
            raise

    def _encrypt_token(self, token) -> str:
        """Encrypt OAuth tokens for storage (JSON-serialize, then Fernet-encrypt)."""
        try:
            token_json = json.dumps(token)
            encrypted = self.fernet.encrypt(token_json.encode()).decode()
            return encrypted
        except Exception as e:
            log.error(f'Error encrypting tokens: {e}')
            raise

    def _decrypt_token(self, token: str):
        """Decrypt OAuth tokens from storage (Fernet-decrypt, then JSON-parse)."""
        try:
            decrypted = self.fernet.decrypt(token.encode()).decode()
            return json.loads(decrypted)
        except Exception as e:
            log.error(f'Error decrypting tokens: {type(e).__name__}: {e}')
            raise

    @staticmethod
    def _resolve_expires_at(token: dict, now: int) -> int:
        """Return the token's 'expires_at', or ``now`` plus the default TTL if it is missing/falsy."""
        return token.get('expires_at') or now + _DEFAULT_TOKEN_TTL_SECONDS

    def _to_model(self, session: OAuthSession, token: Optional[dict] = None) -> OAuthSessionModel:
        """Map an ORM row to an OAuthSessionModel.

        Args:
            session: the ORM row to convert.
            token: plaintext token to embed. When None, the row's stored
                ciphertext is decrypted instead; decryption errors propagate.
        """
        return OAuthSessionModel(
            id=session.id,
            user_id=session.user_id,
            provider=session.provider,
            token=self._decrypt_token(session.token) if token is None else token,
            expires_at=session.expires_at,
            created_at=session.created_at,
            updated_at=session.updated_at,
        )

    async def create_session(
        self,
        user_id: str,
        provider: str,
        token: dict,
        db: Optional[AsyncSession] = None,
    ) -> Optional[OAuthSessionModel]:
        """Create a new OAuth session.

        Returns:
            The stored session, carrying the caller's plaintext ``token``;
            None if anything fails (the error is logged).
        """
        try:
            async with get_async_db_context(db) as db:
                current_time = int(time.time())
                session = OAuthSession(
                    id=str(uuid.uuid4()),
                    user_id=user_id,
                    provider=provider,
                    token=self._encrypt_token(token),
                    expires_at=self._resolve_expires_at(token, current_time),
                    created_at=current_time,
                    updated_at=current_time,
                )

                db.add(session)
                await db.commit()
                await db.refresh(session)

                # Return the caller's plaintext token rather than re-decrypting it.
                return self._to_model(session, token=token)
        except Exception as e:
            log.error(f'Error creating OAuth session: {e}')
            return None

    async def get_session_by_id(
        self, session_id: str, db: Optional[AsyncSession] = None
    ) -> Optional[OAuthSessionModel]:
        """Get OAuth session by ID; None if absent or on any error (including decryption)."""
        try:
            async with get_async_db_context(db) as db:
                result = await db.execute(select(OAuthSession).filter_by(id=session_id))
                session = result.scalars().first()
                return self._to_model(session) if session else None
        except Exception as e:
            log.error(f'Error getting OAuth session by ID: {e}')
            return None

    async def get_session_by_id_and_user_id(
        self, session_id: str, user_id: str, db: Optional[AsyncSession] = None
    ) -> Optional[OAuthSessionModel]:
        """Get OAuth session by ID, only if it belongs to ``user_id``; None otherwise or on error."""
        try:
            async with get_async_db_context(db) as db:
                result = await db.execute(select(OAuthSession).filter_by(id=session_id, user_id=user_id))
                session = result.scalars().first()
                return self._to_model(session) if session else None
        except Exception as e:
            log.error(f'Error getting OAuth session by ID and user ID: {e}')
            return None

    async def get_session_by_provider_and_user_id(
        self, provider: str, user_id: str, db: Optional[AsyncSession] = None
    ) -> Optional[OAuthSessionModel]:
        """Get the most recently created OAuth session for a provider and user; None if none/error."""
        try:
            async with get_async_db_context(db) as db:
                result = await db.execute(
                    select(OAuthSession)
                    .filter_by(provider=provider, user_id=user_id)
                    .order_by(OAuthSession.created_at.desc())
                )
                session = result.scalars().first()
                return self._to_model(session) if session else None
        except Exception as e:
            log.error(f'Error getting OAuth session by provider and user ID: {e}')
            return None

    async def get_sessions_by_user_id(self, user_id: str, db: Optional[AsyncSession] = None) -> List[OAuthSessionModel]:
        """Get all OAuth sessions for a user.

        Rows whose token cannot be decrypted are skipped and then deleted.
        Returns [] on any other error.
        """
        try:
            async with get_async_db_context(db) as db:
                result = await db.execute(select(OAuthSession).filter_by(user_id=user_id))
                sessions = result.scalars().all()

                results = []
                corrupted_ids = []
                for session in sessions:
                    try:
                        results.append(self._to_model(session))
                    except Exception as e:
                        log.warning(
                            f'Skipping OAuth session {session.id} due to decryption failure, deleting corrupted session: {type(e).__name__}: {e}'
                        )
                        corrupted_ids.append(session.id)

                # Delete undecryptable rows in one statement once iteration is done.
                # A commit inside the loop could expire the not-yet-visited ORM
                # objects (if the session factory expires on commit; not visible
                # here), and AsyncSession cannot lazily refresh expired attributes.
                if corrupted_ids:
                    await db.execute(delete(OAuthSession).where(OAuthSession.id.in_(corrupted_ids)))
                    await db.commit()

                return results
        except Exception as e:
            log.error(f'Error getting OAuth sessions by user ID: {e}')
            return []

    async def update_session_by_id(
        self, session_id: str, token: dict, db: Optional[AsyncSession] = None
    ) -> Optional[OAuthSessionModel]:
        """Replace a session's tokens and expiry.

        Returns:
            The re-read session, or None if it does not exist or on error.
        """
        try:
            async with get_async_db_context(db) as db:
                current_time = int(time.time())

                await db.execute(
                    update(OAuthSession)
                    .filter_by(id=session_id)
                    .values(
                        token=self._encrypt_token(token),
                        expires_at=self._resolve_expires_at(token, current_time),
                        updated_at=current_time,
                    )
                )
                await db.commit()

                result = await db.execute(select(OAuthSession).filter_by(id=session_id))
                session = result.scalars().first()
                return self._to_model(session) if session else None
        except Exception as e:
            log.error(f'Error updating OAuth session tokens: {e}')
            return None

    async def delete_session_by_id(self, session_id: str, db: Optional[AsyncSession] = None) -> bool:
        """Delete an OAuth session; True only if a row was removed."""
        try:
            async with get_async_db_context(db) as db:
                result = await db.execute(delete(OAuthSession).filter_by(id=session_id))
                await db.commit()
                return result.rowcount > 0
        except Exception as e:
            log.error(f'Error deleting OAuth session: {e}')
            return False

    async def delete_sessions_by_user_id(self, user_id: str, db: Optional[AsyncSession] = None) -> bool:
        """Delete all OAuth sessions for a user; True unless an error occurred."""
        try:
            async with get_async_db_context(db) as db:
                await db.execute(delete(OAuthSession).filter_by(user_id=user_id))
                await db.commit()
                return True
        except Exception as e:
            log.error(f'Error deleting OAuth sessions by user ID: {e}')
            return False

    async def delete_sessions_by_user_id_and_provider(
        self, user_id: str, provider: str, db: Optional[AsyncSession] = None
    ) -> bool:
        """Delete all OAuth sessions for a specific user and provider; True only if rows were removed."""
        try:
            async with get_async_db_context(db) as db:
                result = await db.execute(delete(OAuthSession).filter_by(user_id=user_id, provider=provider))
                await db.commit()
                return result.rowcount > 0
        except Exception as e:
            log.error(f'Error deleting OAuth sessions for user {user_id} and provider {provider}: {e}')
            return False

    async def delete_sessions_by_provider(self, provider: str, db: Optional[AsyncSession] = None) -> bool:
        """Delete all OAuth sessions for a provider; True unless an error occurred."""
        try:
            async with get_async_db_context(db) as db:
                await db.execute(delete(OAuthSession).filter_by(provider=provider))
                await db.commit()
                return True
        except Exception as e:
            log.error(f'Error deleting OAuth sessions by provider {provider}: {e}')
            return False


OAuthSessions = OAuthSessionTable()