|
|
"""
|
|
|
User authentication service for OpenManus
|
|
|
Handles user registration, login, and session management backed by a D1 database
|
|
|
"""
|
|
|
|
|
|
import json
|
|
|
import sqlite3
|
|
|
from datetime import datetime
|
|
|
from typing import Optional, Tuple
|
|
|
|
|
|
from app.auth import (
|
|
|
User,
|
|
|
UserAuth,
|
|
|
UserSession,
|
|
|
UserSignupRequest,
|
|
|
UserLoginRequest,
|
|
|
AuthResponse,
|
|
|
UserProfile,
|
|
|
)
|
|
|
from app.logger import logger
|
|
|
|
|
|
|
|
|
class AuthService:
    """Authentication service for user management.

    Wraps a DB-API style connection (``cursor()`` / ``commit()`` with ``?``
    placeholders) and stores users in the ``users`` table and login sessions
    in the ``sessions`` table. Every public coroutine catches its own
    exceptions, logs them, and returns a failure value
    (``AuthResponse(success=False)``, ``None`` or ``False``) instead of raising.

    NOTE(review): the database calls are synchronous, so each coroutine blocks
    the event loop while its query runs.
    """

    # Shared SELECT prefix for user lookups. _row_to_user indexes the result
    # positionally, so the column order here and there must stay in sync.
    _USER_SELECT = (
        "SELECT id, mobile_number, full_name, password_hash, avatar_url, "
        "preferences, is_active, created_at, updated_at "
        "FROM users "
    )

    def __init__(self, db_connection=None):
        """Initialize auth service with a database connection.

        Args:
            db_connection: Connection exposing ``cursor()`` and ``commit()``.
                When falsy, every DB helper short-circuits to a failure value
                (``None`` / ``False``).
        """
        self.db = db_connection
        self.logger = logger

    async def register_user(self, signup_data: UserSignupRequest) -> AuthResponse:
        """Register a new user and open a session for them.

        Args:
            signup_data: Mobile number, full name and plaintext password.

        Returns:
            ``AuthResponse`` with ``success=True`` plus the new session id,
            user id and full name on success. ``success=False`` with a message
            if the number is already registered, a DB write fails, or any
            exception is raised.
        """
        try:
            formatted_mobile = UserAuth.format_mobile_number(signup_data.mobile_number)

            # NOTE(review): check-then-insert is not atomic. Two concurrent
            # signups for the same number rely on a UNIQUE constraint (if the
            # schema defines one) to make the second save_user() fail.
            existing_user = await self.get_user_by_mobile(formatted_mobile)
            if existing_user:
                return AuthResponse(
                    success=False, message="User with this mobile number already exists"
                )

            # Take one timestamp so created_at == updated_at on a new row.
            # Kept naive UTC to match the timestamps already stored and parsed
            # by this service.
            now = datetime.utcnow()
            user = User(
                id=UserAuth.generate_user_id(),
                mobile_number=formatted_mobile,
                full_name=signup_data.full_name,
                password_hash=UserAuth.hash_password(signup_data.password),
                created_at=now,
                updated_at=now,
            )

            if not await self.save_user(user):
                return AuthResponse(
                    success=False, message="Failed to create user account"
                )

            session = UserAuth.create_session(user)
            if not await self.save_session(session):
                return AuthResponse(
                    success=False, message="User created but failed to create session"
                )

            # Log the opaque user id rather than the phone number (PII).
            self.logger.info(f"New user registered: {user.id}")

            return AuthResponse(
                success=True,
                message="Account created successfully",
                session_id=session.session_id,
                user_id=user.id,
                full_name=user.full_name,
            )

        except Exception as e:
            self.logger.error(f"User registration error: {str(e)}")
            return AuthResponse(
                success=False, message="An error occurred during registration"
            )

    async def login_user(self, login_data: UserLoginRequest) -> AuthResponse:
        """Authenticate a user by mobile number and password and open a session.

        An unknown number and a wrong password produce the same message, so the
        response does not reveal which one failed. The active-account check
        runs only after the password has been verified.

        NOTE(review): the unknown-number path skips password verification, so
        response timing may still distinguish it from a wrong password.

        Args:
            login_data: Mobile number and plaintext password.

        Returns:
            ``AuthResponse`` with ``success=True`` and session details on
            success; ``success=False`` with a message otherwise.
        """
        try:
            formatted_mobile = UserAuth.format_mobile_number(login_data.mobile_number)

            user = await self.get_user_by_mobile(formatted_mobile)
            if not user:
                return AuthResponse(
                    success=False, message="Invalid mobile number or password"
                )

            if not UserAuth.verify_password(login_data.password, user.password_hash):
                return AuthResponse(
                    success=False, message="Invalid mobile number or password"
                )

            if not user.is_active:
                return AuthResponse(
                    success=False,
                    message="Account is deactivated. Please contact support.",
                )

            session = UserAuth.create_session(user)
            if not await self.save_session(session):
                return AuthResponse(
                    success=False,
                    message="Login successful but failed to create session",
                )

            # Log the opaque user id rather than the phone number (PII).
            self.logger.info(f"User logged in: {user.id}")

            return AuthResponse(
                success=True,
                message="Login successful",
                session_id=session.session_id,
                user_id=user.id,
                full_name=user.full_name,
            )

        except Exception as e:
            self.logger.error(f"User login error: {str(e)}")
            return AuthResponse(success=False, message="An error occurred during login")

    async def validate_session(self, session_id: str) -> Optional[UserSession]:
        """Return the session for ``session_id`` if it is still usable.

        The session row is joined to its user and matches only when the user
        is active. A session whose ``is_valid`` is false is deleted before
        returning ``None``. ``is_valid`` is defined on ``UserSession``
        elsewhere and presumably checks ``expires_at``; verify there.

        Returns:
            The ``UserSession``, or ``None`` if there is no DB connection, no
            matching active session, the session is no longer valid, or an
            error occurs.
        """
        try:
            if not self.db:
                return None

            cursor = self.db.cursor()
            cursor.execute(
                """
                SELECT s.id, s.user_id, u.mobile_number, u.full_name,
                       s.created_at, s.expires_at
                FROM sessions s
                JOIN users u ON s.user_id = u.id
                WHERE s.id = ? AND u.is_active = 1
                """,
                (session_id,),
            )

            row = cursor.fetchone()
            if not row:
                return None

            session = UserSession(
                session_id=row[0],
                user_id=row[1],
                mobile_number=row[2],
                full_name=row[3],
                created_at=datetime.fromisoformat(row[4]),
                expires_at=datetime.fromisoformat(row[5]),
            )

            if not session.is_valid:
                # Clean up the stale row so it is not matched again.
                await self.delete_session(session_id)
                return None

            return session

        except Exception as e:
            self.logger.error(f"Session validation error: {str(e)}")
            return None

    async def logout_user(self, session_id: str) -> bool:
        """Log a user out by deleting their session.

        Returns:
            The result of ``delete_session``. This is ``True`` whenever the
            DELETE succeeds, even if no matching session existed.
        """
        return await self.delete_session(session_id)

    async def get_user_profile(self, user_id: str) -> Optional[UserProfile]:
        """Return a public profile for an active user.

        The mobile number is masked via ``UserProfile.mask_mobile_number``, and
        ``created_at`` is returned as an ISO-8601 string.

        Returns:
            The ``UserProfile``, or ``None`` if the user is missing or
            inactive, or an error occurs.
        """
        try:
            user = await self.get_user_by_id(user_id)
            if not user:
                return None

            return UserProfile(
                user_id=user.id,
                full_name=user.full_name,
                mobile_number=UserProfile.mask_mobile_number(user.mobile_number),
                avatar_url=user.avatar_url,
                created_at=user.created_at.isoformat() if user.created_at else None,
            )

        except Exception as e:
            self.logger.error(f"Get user profile error: {str(e)}")
            return None

    def _rollback(self) -> None:
        """Best-effort rollback after a failed write.

        Ends any transaction the failed statement may have opened, so the
        shared connection is not left holding it. Connections without a
        ``rollback()`` method are skipped. Rollback errors are logged, not
        raised, so they never replace the original failure.
        """
        rollback = getattr(self.db, "rollback", None)
        if rollback is None:
            return
        try:
            rollback()
        except Exception as e:
            self.logger.error(f"Rollback error: {str(e)}")

    async def save_user(self, user: User) -> bool:
        """Insert ``user`` into the ``users`` table and commit.

        Timestamps are stored as ISO-8601 text (``None`` when unset).
        NOTE(review): ``preferences`` is bound as-is, so it must already be
        a DB-bindable value (e.g. a JSON string); confirm against ``User``.

        Returns:
            ``True`` on commit, ``False`` if there is no DB connection or the
            insert fails. On failure the transaction is rolled back.
        """
        try:
            if not self.db:
                return False

            cursor = self.db.cursor()
            cursor.execute(
                """
                INSERT INTO users (id, mobile_number, full_name, password_hash,
                                   avatar_url, preferences, is_active, created_at, updated_at)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    user.id,
                    user.mobile_number,
                    user.full_name,
                    user.password_hash,
                    user.avatar_url,
                    user.preferences,
                    user.is_active,
                    user.created_at.isoformat() if user.created_at else None,
                    user.updated_at.isoformat() if user.updated_at else None,
                ),
            )

            self.db.commit()
            return True

        except Exception as e:
            self.logger.error(f"Save user error: {str(e)}")
            self._rollback()
            return False

    @staticmethod
    def _row_to_user(row) -> User:
        """Map a row selected with ``_USER_SELECT`` onto a ``User``.

        ``is_active`` is coerced to ``bool``. Timestamp columns are parsed
        from ISO-8601 text, with empty/NULL values mapped to ``None``.
        """
        return User(
            id=row[0],
            mobile_number=row[1],
            full_name=row[2],
            password_hash=row[3],
            avatar_url=row[4],
            preferences=row[5],
            is_active=bool(row[6]),
            created_at=datetime.fromisoformat(row[7]) if row[7] else None,
            updated_at=datetime.fromisoformat(row[8]) if row[8] else None,
        )

    async def get_user_by_mobile(self, mobile_number: str) -> Optional[User]:
        """Look up a user by exact (already formatted) mobile number.

        Inactive users are returned too, so callers such as ``login_user`` can
        report deactivation explicitly.

        Returns:
            The ``User``, or ``None`` if not found, if there is no DB
            connection, or on error.
        """
        try:
            if not self.db:
                return None

            cursor = self.db.cursor()
            cursor.execute(
                self._USER_SELECT + "WHERE mobile_number = ?",
                (mobile_number,),
            )

            row = cursor.fetchone()
            if not row:
                return None

            return self._row_to_user(row)

        except Exception as e:
            self.logger.error(f"Get user by mobile error: {str(e)}")
            return None

    async def get_user_by_id(self, user_id: str) -> Optional[User]:
        """Look up an active user by id.

        Returns:
            The ``User``, or ``None`` if not found, if the user is inactive, if
            there is no DB connection, or on error.
        """
        try:
            if not self.db:
                return None

            cursor = self.db.cursor()
            cursor.execute(
                self._USER_SELECT + "WHERE id = ? AND is_active = 1",
                (user_id,),
            )

            row = cursor.fetchone()
            if not row:
                return None

            return self._row_to_user(row)

        except Exception as e:
            self.logger.error(f"Get user by ID error: {str(e)}")
            return None

    async def save_session(self, session: UserSession) -> bool:
        """Insert a login session into the ``sessions`` table and commit.

        Each row gets the fixed title "User Session", JSON metadata
        ``{"login_type": "mobile_password"}``, and ``updated_at`` equal to
        ``created_at``.

        Returns:
            ``True`` on commit, ``False`` if there is no DB connection or the
            insert fails. On failure the transaction is rolled back.
        """
        try:
            if not self.db:
                return False

            cursor = self.db.cursor()
            cursor.execute(
                """
                INSERT INTO sessions (id, user_id, title, metadata, created_at,
                                      updated_at, expires_at)
                VALUES (?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    session.session_id,
                    session.user_id,
                    "User Session",
                    json.dumps({"login_type": "mobile_password"}),
                    session.created_at.isoformat(),
                    session.created_at.isoformat(),
                    session.expires_at.isoformat(),
                ),
            )

            self.db.commit()
            return True

        except Exception as e:
            self.logger.error(f"Save session error: {str(e)}")
            self._rollback()
            return False

    async def delete_session(self, session_id: str) -> bool:
        """Delete the session row with ``session_id`` and commit.

        Returns:
            ``True`` if the DELETE committed, even when no row matched.
            ``False`` if there is no DB connection or the statement fails;
            on failure the transaction is rolled back.
        """
        try:
            if not self.db:
                return False

            cursor = self.db.cursor()
            cursor.execute("DELETE FROM sessions WHERE id = ?", (session_id,))
            self.db.commit()
            return True

        except Exception as e:
            self.logger.error(f"Delete session error: {str(e)}")
            self._rollback()
            return False
|
|
|
|