import logging
import uuid
from datetime import datetime
from typing import Optional

from sqlalchemy import select

from core.database import async_session_maker
from core.models import User
from google_auth_service.google_provider import GoogleUserInfo
from google_auth_service.user_store import BaseUserStore

logger = logging.getLogger(__name__)

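
class SQLAlchemyUserStore(BaseUserStore):
    """
    Adapter that lets the google_auth_service library store and look up users
    through the application's SQLAlchemy ``User`` model.

    Each method opens its own short-lived session from ``async_session_maker``.
    """
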
    async def get(self, user_id: str) -> Optional[User]:
        """Return the user with the given ``user_id``, or None if there is none."""
        async with async_session_maker() as db:
            query = select(User).where(User.user_id == user_id)
            result = await db.execute(query)
            return result.scalar_one_or_none()

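    async def save(self, google_info: GoogleUserInfo) -> User:
        """Create or update the user whose email matches ``google_info.email``.

        Existing users get their name, picture and ``last_used_at`` refreshed;
        ``google_id`` is only filled in if it was previously empty.
        """
        async with async_session_maker() as db:
            query = select(User).where(User.email == google_info.email)
            result = await db.execute(query)
            user = result.scalar_one_or_none()
            is_new = user is None

            if user is not None:
                # Update the existing account; never overwrite a linked google_id.
                if not user.google_id:
                    user.google_id = google_info.google_id
                user.name = google_info.name
                user.profile_picture = google_info.picture
                user.last_used_at = datetime.utcnow()
            else:
                # First sign-in: create the account.
                user = User(
                    user_id="usr_" + str(uuid.uuid4()),
                    email=google_info.email,
                    google_id=google_info.google_id,
                    name=google_info.name,
                    profile_picture=google_info.picture,
                    credits=0,  # New accounts start with zero credits (business rule).
                    token_version=1,  # Initial version; invalidate_token() increments it.
                )
                db.add(user)

            await db.commit()
            await db.refresh(user)

            if is_new:
                # Log only after the commit succeeded, so failed inserts are not reported.
                logger.info("New user created: %s", user.email)
            return user
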
    async def get_token_version(self, user_id: str) -> Optional[int]:
        """Return the user's current token version, or None if the user does not exist."""
        async with async_session_maker() as db:
            query = select(User.token_version).where(User.user_id == user_id)
            result = await db.execute(query)
            return result.scalar_one_or_none()

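    async def invalidate_token(self, user_id: str) -> None:
        """Increment the user's token version.

        Tokens issued with an older version are expected to be rejected by the
        auth layer once the stored version changes. Unknown user IDs are ignored.
        """
        async with async_session_maker() as db:
            query = select(User).where(User.user_id == user_id)
            result = await db.execute(query)
            user = result.scalar_one_or_none()
            if user:
                # A missing (NULL) version is treated as 1, so it becomes 2.
                user.token_version = (user.token_version or 1) + 1
                await db.commit()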
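

# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical; nothing in this module calls it).
#
# This store only persists ``token_version``. The comparison against the
# version embedded in an issued token happens elsewhere, presumably inside
# google_auth_service, whose exact logic is not shown here. Assuming each
# token carries the version it was issued with, a minimal check could look
# like the function below: invalidate_token() bumps the stored value, so any
# token issued before the bump no longer matches and is rejected.
# The name ``is_token_version_current`` is hypothetical.
# ---------------------------------------------------------------------------


def is_token_version_current(
    claim_version: Optional[int], stored_version: Optional[int]
) -> bool:
    """Hypothetical helper: accept a token only if its version matches the DB.

    ``claim_version`` is the version read from the token (an assumed claim);
    ``stored_version`` is what ``get_token_version()`` returns, which is None
    for unknown users.
    """
    if claim_version is None or stored_version is None:
        # Missing claim or unknown user: fail closed.
        return False
    return claim_version == stored_version


# Example: a token issued at version 1 is accepted until invalidate_token()
# raises the stored version to 2:
#   is_token_version_current(1, 1)  -> True
#   is_token_version_current(1, 2)  -> False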