# apigateway/core/user_store_adapter.py
# Last change (e6ec780, jebin2): Replace custom auth_service with google-auth-service library
import logging
import uuid
from datetime import datetime, timezone
from typing import Any, Optional

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from google_auth_service.google_provider import GoogleUserInfo
from google_auth_service.user_store import BaseUserStore

from core.database import async_session_maker
from core.models import User
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class SQLAlchemyUserStore(BaseUserStore):
    """
    Adapter to allow the GoogleAuth library to use SQLAlchemy models.

    Implements the ``BaseUserStore`` hooks on top of the project's ``User``
    model. Every method opens its own short-lived session from
    ``async_session_maker``, and that session is closed before the method
    returns.
    """

    async def get(self, user_id: str) -> Optional[User]:
        """Return the user whose ``user_id`` matches, or ``None`` if absent."""
        async with async_session_maker() as db:
            query = select(User).where(User.user_id == user_id)
            result = await db.execute(query)
            return result.scalar_one_or_none()

    async def save(self, google_info: GoogleUserInfo) -> User:
        """
        Create or update the user that matches a Google sign-in.

        Users are matched by email.

        - Existing user: the Google id is filled in only if it was missing.
          Name, profile picture and ``last_used_at`` are always refreshed.
        - No match: a new user is created with a ``usr_``-prefixed UUID id,
          zero credits and token version 1.

        Args:
            google_info: Profile data returned by the Google provider.

        Returns:
            The persisted ``User``, reloaded from the database after commit.
        """
        async with async_session_maker() as db:
            query = select(User).where(User.email == google_info.email)
            result = await db.execute(query)
            user = result.scalar_one_or_none()
            is_new = user is None

            if user:
                # Update existing: keep an already-linked Google id.
                if not user.google_id:
                    user.google_id = google_info.google_id
                user.name = google_info.name
                user.profile_picture = google_info.picture
                # Naive UTC timestamp: the same value datetime.utcnow()
                # produced, without the API deprecated in Python 3.12.
                user.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None)
            else:
                # Create new
                user = User(
                    user_id="usr_" + str(uuid.uuid4()),
                    email=google_info.email,
                    google_id=google_info.google_id,
                    name=google_info.name,
                    profile_picture=google_info.picture,
                    credits=0,  # Business logic: new accounts start with no credits.
                    token_version=1,
                )
                db.add(user)

            await db.commit()
            # Reload attributes expired by the commit so they remain readable
            # after the session closes.
            await db.refresh(user)

            # Log only once the row is actually committed, so a failed commit
            # never produces a misleading "created" message.
            if is_new:
                logger.info("New user created: %s", user.email)
            return user

    async def get_token_version(self, user_id: str) -> Optional[int]:
        """
        Return the stored ``token_version`` for ``user_id``.

        Returns ``None`` when no user matches, or when the stored value
        itself is NULL.
        """
        async with async_session_maker() as db:
            query = select(User.token_version).where(User.user_id == user_id)
            result = await db.execute(query)
            return result.scalar_one_or_none()

    async def invalidate_token(self, user_id: str) -> None:
        """
        Increment the user's ``token_version``.

        Presumably the auth library compares a token's embedded version
        against ``get_token_version`` — TODO confirm. Bumping the version
        then makes previously issued tokens stale.

        A missing (falsy) version is treated as 1 before incrementing.
        This is a silent no-op when no user matches ``user_id``.
        """
        async with async_session_maker() as db:
            query = select(User).where(User.user_id == user_id)
            result = await db.execute(query)
            user = result.scalar_one_or_none()
            if user:
                user.token_version = (user.token_version or 1) + 1
                await db.commit()