Spaces:
Sleeping
Sleeping
| import random | |
| import uuid | |
| from jose import jwt | |
| from datetime import datetime, timedelta | |
| from fastapi import HTTPException | |
| from app.models.user_model import BookMyServiceUserModel | |
| from app.models.otp_model import BookMyServiceOTPModel | |
| from app.models.social_account_model import SocialAccountModel | |
| from app.models.refresh_token_model import RefreshTokenModel | |
| from app.core.config import settings | |
| from app.utils.common_utils import is_email, validate_identifier | |
| from app.utils.jwt import create_refresh_token | |
| from app.schemas.user_schema import UserRegisterRequest | |
| import logging | |
# Logger used by every UserService method below. It is named explicitly
# "user_service" (not __name__), so logging configuration must target that name.
logger = logging.getLogger("user_service")
class UserService:
    """OTP login, OTP/OAuth registration and token issuance for BookMyService users.

    Every public method is a ``@staticmethod`` coroutine, so it can be awaited
    as ``UserService.method(...)`` (as before) or on an instance. Client-facing
    failures are raised as ``fastapi.HTTPException``.
    """

    # NOTE(security): fixed OTP carried over from the original implementation,
    # which labelled it "hardcoded for testing purposes". Nothing in this class
    # delivers the code to the user (no SMS/e-mail call is made), so swapping in
    # a random value here would make login impossible. Once a delivery channel
    # exists, generate the code with the ``secrets`` module and remove this.
    _TEST_OTP = "777777"
    # Maximum OTP requests accepted from one client IP per rate-limit window.
    _OTP_IP_LIMIT = 10
    # Rate-limit window passed to increment_rate_limit, in seconds (1 hour).
    _OTP_IP_WINDOW_SECONDS = 3600

    @staticmethod
    def _create_access_token(customer_id: str) -> str:
        """Return a signed JWT access token whose ``sub`` claim is *customer_id*.

        The token expires ``settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES`` from now.
        """
        token_data = {
            "sub": customer_id,
            "exp": datetime.utcnow() + timedelta(minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES),
        }
        return jwt.encode(token_data, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)

    @staticmethod
    async def _issue_tokens(
        customer_id: str,
        *,
        remember_me: bool = False,
        device_info: str = None,
        client_ip: str = None,
    ) -> dict:
        """Issue an access token plus a stored, rotatable refresh token.

        The helper works in four steps:

        1. Open a new refresh-token family for *customer_id*.
        2. Sign the access token.
        3. Build the refresh token with ``create_refresh_token``.
        4. Persist the refresh token's metadata via
           ``RefreshTokenModel.store_refresh_token``.

        All login/registration paths go through here, so every refresh token
        takes part in the same rotation scheme.

        Returns:
            dict with ``access_token``, ``refresh_token``, ``token_type`` and
            ``expires_in`` (access-token lifetime in seconds).
        """
        family_id = await RefreshTokenModel.create_token_family(customer_id, device_info)
        access_token = UserService._create_access_token(customer_id)
        refresh_token, token_id, expires_at = create_refresh_token(
            {"sub": customer_id},
            remember_me=remember_me,
            family_id=family_id,
        )
        await RefreshTokenModel.store_refresh_token(
            token_id=token_id,
            customer_id=customer_id,
            family_id=family_id,
            expires_at=expires_at,
            remember_me=remember_me,
            device_info=device_info,
            ip_address=client_ip,
        )
        # Token material is deliberately not logged, not even a prefix.
        logger.info(f"JWT tokens created successfully for user: {customer_id}")
        return {
            "access_token": access_token,
            "refresh_token": refresh_token,
            "token_type": "bearer",
            "expires_in": settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60,
        }

    @staticmethod
    async def send_otp(identifier: str, phone: str = None, client_ip: str = None):
        """Store an OTP for *identifier*, enforcing a per-IP request limit.

        Args:
            identifier: Login identifier. It must pass ``validate_identifier``,
                whose result ("phone" or "email") selects the stored phone.
            phone: Phone number to associate when *identifier* is an e-mail.
            client_ip: Caller IP. When given, requests are counted per IP and
                rejected once ``_OTP_IP_LIMIT`` is reached within the window.

        Raises:
            HTTPException: 429 when the IP limit is reached, 400 when the
                identifier is invalid, 500 on any other failure.
        """
        logger.info(f"UserService.send_otp called - identifier: {identifier}, phone: {phone}, ip: {client_ip}")
        try:
            identifier_type = validate_identifier(identifier)
            logger.info(f"Identifier type: {identifier_type}")

            ip_rate_key = f"otp_ip_rate:{client_ip}" if client_ip else None
            if ip_rate_key:
                ip_attempts = await BookMyServiceOTPModel.get_rate_limit_count(ip_rate_key)
                if ip_attempts >= UserService._OTP_IP_LIMIT:
                    logger.warning(f"IP rate limit exceeded for {client_ip}")
                    raise HTTPException(status_code=429, detail="Too many OTP requests from this IP")

            # A phone identifier is its own phone number. An e-mail identifier
            # uses the optional *phone* argument; otherwise no phone is stored.
            if identifier_type == "phone":
                phone_number = identifier
            elif identifier_type == "email" and phone:
                phone_number = phone
            else:
                phone_number = None

            otp = UserService._TEST_OTP
            logger.info(f"Generated hardcoded OTP for identifier: {identifier}")
            await BookMyServiceOTPModel.store_otp(identifier, phone_number, otp)

            # The request is only counted once the OTP was stored successfully.
            if ip_rate_key:
                await BookMyServiceOTPModel.increment_rate_limit(ip_rate_key, UserService._OTP_IP_WINDOW_SECONDS)

            logger.info(f"OTP stored successfully for identifier: {identifier}")
            # NOTE(review): this method performs no delivery itself; confirm
            # whether something downstream actually sends the OTP.
            logger.info(f"OTP sent to {identifier}")
        except HTTPException:
            # Deliberate HTTP errors (the 429 above) must reach the client
            # unchanged. Previously the generic handler turned them into a 500.
            raise
        except ValueError as ve:
            logger.error(f"Validation error for identifier {identifier}: {str(ve)}")
            raise HTTPException(status_code=400, detail=str(ve)) from ve
        except Exception as e:
            logger.error(f"Error in send_otp for identifier {identifier}: {str(e)}", exc_info=True)
            raise HTTPException(status_code=500, detail="Failed to send OTP") from e

    @staticmethod
    async def otp_login_handler(
        identifier: str,
        otp: str,
        client_ip: str = None,
        remember_me: bool = False,
        device_info: str = None,
    ):
        """Log a user in with an OTP and return an access/refresh token pair.

        Args:
            identifier: Login identifier the OTP was issued for.
            otp: OTP submitted by the client.
            client_ip: Caller IP, forwarded to OTP verification, failed-attempt
                tracking and refresh-token metadata.
            remember_me: Forwarded to refresh-token creation.
            device_info: Stored with the token family and refresh token.

        Returns:
            dict with the token fields plus ``customer_id``, ``name``,
            ``email``, ``profile_picture``, ``auth_method``, ``provider``
            (always None) and ``security_info`` (always None).

        Raises:
            HTTPException: 400 for an invalid identifier or OTP, 423 when the
                account is locked, 404 when no user matches, 500 otherwise.
        """
        # The submitted OTP is never logged: it is a live credential.
        logger.info(
            f"UserService.otp_login_handler called - identifier: {identifier}, "
            f"ip: {client_ip}, remember_me: {remember_me}"
        )
        try:
            identifier_type = validate_identifier(identifier)
            logger.info(f"Identifier type: {identifier_type}")

            if await BookMyServiceOTPModel.is_account_locked(identifier):
                logger.warning(f"Account locked for identifier: {identifier}")
                raise HTTPException(status_code=423, detail="Account temporarily locked due to too many failed attempts")

            logger.info(f"Verifying OTP for identifier: {identifier}")
            otp_valid = await BookMyServiceOTPModel.verify_otp(identifier, otp, client_ip)
            logger.info(f"OTP verification result: {otp_valid}")
            if not otp_valid:
                logger.warning(f"Invalid or expired OTP for identifier: {identifier}")
                # Presumably counted toward the is_account_locked check above.
                await BookMyServiceOTPModel.track_failed_attempt(identifier, client_ip)
                raise HTTPException(status_code=400, detail="Invalid or expired OTP")

            await BookMyServiceOTPModel.clear_failed_attempts(identifier)
            logger.info(f"OTP verification successful for identifier: {identifier}")

            logger.info(f"Looking up user by identifier: {identifier}")
            user = await BookMyServiceUserModel.find_by_identifier(identifier)
            logger.info(f"User lookup result: {user is not None}")
            if not user:
                logger.warning(f"No user found for identifier: {identifier}")
                raise HTTPException(status_code=404, detail="User not found")

            customer_id = user.get("customer_id")
            logger.info(f"User found for identifier: {identifier}, customer_id: {customer_id}")

            logger.info("Creating JWT token for authenticated user")
            tokens = await UserService._issue_tokens(
                customer_id,
                remember_me=remember_me,
                device_info=device_info,
                client_ip=client_ip,
            )
            return {
                **tokens,
                "customer_id": customer_id,
                "name": user.get("name"),
                "email": user.get("email"),
                "profile_picture": user.get("profile_picture"),
                "auth_method": user.get("auth_mode"),
                "provider": None,
                "security_info": None,
            }
        except ValueError as ve:
            logger.error(f"Validation error for identifier {identifier}: {str(ve)}")
            raise HTTPException(status_code=400, detail=str(ve)) from ve
        except HTTPException as e:
            logger.error(f"HTTP error in otp_login_handler for {identifier}: {e.status_code} - {e.detail}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error in otp_login_handler for {identifier}: {str(e)}", exc_info=True)
            raise HTTPException(status_code=500, detail="Internal server error during login") from e

    @staticmethod
    async def register(data: UserRegisterRequest, decoded):
        """Register a new user via OTP or OAuth and return auth tokens.

        Args:
            data: Registration request. ``name``, ``email`` and ``phone`` are
                required in every mode. ``mode`` selects "otp" (requires
                ``otp``, verified against ``phone``) or "oauth" (requires
                ``oauth_token`` and ``provider``).
            decoded: Decoded OAuth payload, only read in "oauth" mode. Its
                ``user_info`` entry supplies the provider user id (``sub`` or
                ``id``) and, optionally, ``picture``.

        Returns:
            dict with the token fields plus ``customer_id``, ``name``,
            ``email``, ``profile_picture``, ``auth_method``, ``provider`` and
            ``security_info``. If the OAuth identity is already linked to an
            existing user, that user is logged in instead of creating a new one.

        Raises:
            HTTPException: 400 for missing/invalid input or OTP, 409 when the
                e-mail or phone is already registered.
        """
        # Only the mode is logged: the request carries the OTP / OAuth token.
        logger.info(f"Registering user with mode: {data.mode}")

        # Mandatory fields for every registration mode.
        if not data.name or not data.name.strip():
            raise HTTPException(status_code=400, detail="Name is required")
        if not data.email:
            raise HTTPException(status_code=400, detail="Email is required")
        if not data.phone or not data.phone.strip():
            raise HTTPException(status_code=400, detail="Phone is required")

        user_info: dict = {}
        provider_customer_id = None

        if data.mode == "otp":
            # OTP registration always verifies against the phone number.
            identifier = data.phone
            try:
                identifier_type = validate_identifier(identifier)
                if identifier_type != "phone":
                    raise ValueError("Phone number format is invalid")
                logger.info(f"Registration identifier type: {identifier_type}")
            except ValueError as ve:
                logger.error(f"Invalid phone format during registration: {str(ve)}")
                raise HTTPException(status_code=400, detail=str(ve)) from ve

            redis_key = f"bms_otp:{identifier}"
            logger.info(f"Verifying OTP for Redis key: {redis_key}")
            if not data.otp:
                raise HTTPException(status_code=400, detail="OTP is required")
            if not await BookMyServiceOTPModel.verify_otp(identifier, data.otp):
                raise HTTPException(status_code=400, detail="Invalid or expired OTP")

        elif data.mode == "oauth":
            if not data.oauth_token or not data.provider:
                raise HTTPException(status_code=400, detail="OAuth token and provider are required")

            # Tolerate a missing payload or a null user_info entry; both used
            # to raise AttributeError and surface as an unhandled 500.
            user_info = (decoded or {}).get("user_info") or {}
            provider_customer_id = user_info.get("sub") or user_info.get("id")
            if not provider_customer_id:
                raise HTTPException(status_code=400, detail="Invalid OAuth user information")

            existing_social_account = await SocialAccountModel.find_by_provider_and_customer_id(
                data.provider, provider_customer_id
            )
            if existing_social_account:
                existing_user = await BookMyServiceUserModel.collection.find_one({
                    "customer_id": existing_social_account["customer_id"]
                })
                if existing_user:
                    # Already linked. Refresh the stored provider profile, then
                    # log in through the same rotating-token flow as every
                    # other path (previously a bare, unstored JWT was issued).
                    await SocialAccountModel.update_social_account(data.provider, provider_customer_id, user_info)
                    linked_customer_id = existing_user["customer_id"]
                    tokens = await UserService._issue_tokens(
                        linked_customer_id,
                        remember_me=data.remember_me,
                        device_info=data.device_info,
                    )
                    return {
                        **tokens,
                        "customer_id": linked_customer_id,
                        "name": existing_user.get("name"),
                        "email": existing_user.get("email"),
                        "profile_picture": existing_user.get("profile_picture"),
                        "auth_method": existing_user.get("auth_mode"),
                        "provider": data.provider,
                        "security_info": None,
                    }
                # NOTE(review): the link points at a user that no longer exists.
                # Registration continues and create_social_account below is
                # called for an already-linked provider id. Confirm whether it
                # upserts or whether the stale link must be removed first.
                logger.warning(
                    f"Social account {data.provider}:{provider_customer_id} references missing user "
                    f"{existing_social_account['customer_id']}"
                )

        else:
            raise HTTPException(status_code=400, detail="Unsupported registration mode")

        # Refuse duplicate contact details.
        # NOTE(review): this check-then-insert is not atomic. Concurrent
        # requests presumably need a unique index on email/phone to be safe.
        if await BookMyServiceUserModel.exists_by_email_or_phone(email=data.email, phone=data.phone):
            raise HTTPException(status_code=409, detail="User with this email or phone already exists")

        # A fresh UUID4 collides with an existing customer_id only with
        # negligible probability, so the old lookup by customer_id (which
        # could effectively never match) is gone.
        customer_id = str(uuid.uuid4())

        user_doc = {
            "customer_id": customer_id,
            "name": data.name,
            "email": data.email,
            "phone": data.phone,
            "auth_mode": data.mode,
            "created_at": datetime.utcnow(),
        }
        # Copy the provider's profile picture, if any, onto the new user.
        if data.mode == "oauth" and user_info.get("picture"):
            user_doc["profile_picture"] = user_info["picture"]

        await BookMyServiceUserModel.collection.insert_one(user_doc)
        logger.info(f"Created new user: {customer_id}")

        if data.mode == "oauth":
            await SocialAccountModel.create_social_account(
                customer_id, data.provider, provider_customer_id, user_info
            )
            logger.info(f"Created social account link for {data.provider} -> {customer_id}")

        # The client IP is not available here, so ip_address is stored as None;
        # the router could pass it in if needed.
        tokens = await UserService._issue_tokens(
            customer_id,
            remember_me=data.remember_me,
            device_info=data.device_info,
        )
        return {
            **tokens,
            "customer_id": customer_id,
            "name": data.name,
            "email": data.email,
            "profile_picture": user_doc.get("profile_picture"),
            "auth_method": data.mode,
            "provider": data.provider if data.mode == "oauth" else None,
            "security_info": None,
        }