Spaces:
Sleeping
Sleeping
Commit ·
d846c9a
1
Parent(s): 91dc3b5
feat(auth): implement enhanced social login with security middleware
Browse files- Add Facebook OAuth support with token verification
- Implement rate limiting and IP-based security for OAuth logins
- Create account management endpoints for social account linking
- Add security middleware for request logging and device tracking
- Enhance OTP verification with account locking and rate limiting
- Update user schemas with social account and security fields
- .env +3 -2
- app/app.py +10 -1
- app/core/config.py +16 -0
- app/middleware/rate_limiter.py +27 -0
- app/middleware/security_middleware.py +237 -0
- app/models/otp_model.py +132 -6
- app/models/social_account_model.py +257 -0
- app/models/social_security_model.py +171 -0
- app/routers/account_router.py +218 -0
- app/routers/user_router.py +81 -18
- app/schemas/user_schema.py +71 -11
- app/services/account_service.py +396 -0
- app/services/user_service.py +105 -18
- app/utils/social_utils.py +71 -1
.env
CHANGED
|
@@ -5,11 +5,12 @@ DB_NAME=book-my-service
|
|
| 5 |
|
| 6 |
DATABASE_URI=postgresql+asyncpg://trans_owner:BookMyService7@ep-sweet-surf-a1qeduoy.ap-southeast-1.aws.neon.tech/bookmyservice?options=-csearch_path%3Dtrans
|
| 7 |
|
| 8 |
-
CACHE_URI=redis-11382.c305.ap-south-1-1.ec2.redns.redis-cloud.com:11382
|
| 9 |
|
| 10 |
#CACHE_URI=redis-11521.crce182.ap-south-1-1.ec2.redns.redis-cloud.com:11521
|
| 11 |
|
| 12 |
-
|
|
|
|
| 13 |
|
| 14 |
|
| 15 |
RAZORPAY_KEY_ID=rzp_test_2UTAol2AFSV5VN
|
|
|
|
| 5 |
|
| 6 |
DATABASE_URI=postgresql+asyncpg://trans_owner:BookMyService7@ep-sweet-surf-a1qeduoy.ap-southeast-1.aws.neon.tech/bookmyservice?options=-csearch_path%3Dtrans
|
| 7 |
|
| 8 |
+
#CACHE_URI=redis-11382.c305.ap-south-1-1.ec2.redns.redis-cloud.com:11382
|
| 9 |
|
| 10 |
#CACHE_URI=redis-11521.crce182.ap-south-1-1.ec2.redns.redis-cloud.com:11521
|
| 11 |
|
| 12 |
+
CACHE_URI=localhost:6379
|
| 13 |
+
CACHE_K=
|
| 14 |
|
| 15 |
|
| 16 |
RAZORPAY_KEY_ID=rzp_test_2UTAol2AFSV5VN
|
app/app.py
CHANGED
|
@@ -2,7 +2,9 @@
|
|
| 2 |
|
| 3 |
from fastapi import FastAPI
|
| 4 |
from fastapi.middleware.cors import CORSMiddleware
|
| 5 |
-
from app.routers import user_router, profile_router
|
|
|
|
|
|
|
| 6 |
import logging
|
| 7 |
import sys
|
| 8 |
|
|
@@ -29,6 +31,12 @@ logging.getLogger("fastapi").setLevel(logging.INFO)
|
|
| 29 |
|
| 30 |
app = FastAPI(title="BookMyService User Management Service")
|
| 31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
app.add_middleware(
|
| 33 |
CORSMiddleware,
|
| 34 |
allow_origins=["*"],
|
|
@@ -39,6 +47,7 @@ app.add_middleware(
|
|
| 39 |
|
| 40 |
app.include_router(user_router.router, prefix="/auth", tags=["user_auth"])
|
| 41 |
app.include_router(profile_router.router, prefix="/profile", tags=["profile"])
|
|
|
|
| 42 |
|
| 43 |
@app.get("/")
|
| 44 |
def root():
|
|
|
|
| 2 |
|
| 3 |
from fastapi import FastAPI
|
| 4 |
from fastapi.middleware.cors import CORSMiddleware
|
| 5 |
+
from app.routers import user_router, profile_router, account_router
|
| 6 |
+
from app.middleware.rate_limiter import RateLimitMiddleware
|
| 7 |
+
from app.middleware.security_middleware import SecurityMiddleware
|
| 8 |
import logging
|
| 9 |
import sys
|
| 10 |
|
|
|
|
| 31 |
|
| 32 |
app = FastAPI(title="BookMyService User Management Service")
|
| 33 |
|
| 34 |
+
# Add security middleware (should be added first for proper request logging)
|
| 35 |
+
app.add_middleware(SecurityMiddleware)
|
| 36 |
+
|
| 37 |
+
# Add rate limiting middleware
|
| 38 |
+
app.add_middleware(RateLimitMiddleware, calls=100, period=60)
|
| 39 |
+
|
| 40 |
app.add_middleware(
|
| 41 |
CORSMiddleware,
|
| 42 |
allow_origins=["*"],
|
|
|
|
| 47 |
|
| 48 |
app.include_router(user_router.router, prefix="/auth", tags=["user_auth"])
|
| 49 |
app.include_router(profile_router.router, prefix="/profile", tags=["profile"])
|
| 50 |
+
app.include_router(account_router.router, prefix="/account", tags=["account_management"])
|
| 51 |
|
| 52 |
@app.get("/")
|
| 53 |
def root():
|
app/core/config.py
CHANGED
|
@@ -12,19 +12,35 @@ class Settings:
|
|
| 12 |
CACHE_URI: str = os.getenv("CACHE_URI")
|
| 13 |
CACHE_K: str = os.getenv("CACHE_K")
|
| 14 |
|
|
|
|
| 15 |
SECRET_KEY: str = os.getenv("SECRET_KEY", "B00Kmyservice@7")
|
| 16 |
ALGORITHM: str = os.getenv("ALGORITHM", "HS256")
|
| 17 |
|
|
|
|
| 18 |
TWILIO_ACCOUNT_SID: str = os.getenv("TWILIO_ACCOUNT_SID")
|
| 19 |
TWILIO_AUTH_TOKEN: str = os.getenv("TWILIO_AUTH_TOKEN")
|
| 20 |
TWILIO_SMS_FROM: str = os.getenv("TWILIO_SMS_FROM")
|
| 21 |
|
|
|
|
| 22 |
SMTP_HOST: str = os.getenv("SMTP_HOST")
|
| 23 |
SMTP_PORT: int = int(os.getenv("SMTP_PORT", "587"))
|
| 24 |
SMTP_USER: str = os.getenv("SMTP_USER")
|
| 25 |
SMTP_PASS: str = os.getenv("SMTP_PASS")
|
| 26 |
SMTP_FROM: str = os.getenv("SMTP_FROM")
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
def __post_init__(self):
|
| 29 |
if not self.MONGO_URI or not self.DB_NAME:
|
| 30 |
raise ValueError("MongoDB URI or DB_NAME not configured.")
|
|
|
|
| 12 |
CACHE_URI: str = os.getenv("CACHE_URI")
|
| 13 |
CACHE_K: str = os.getenv("CACHE_K")
|
| 14 |
|
| 15 |
+
# JWT
|
| 16 |
SECRET_KEY: str = os.getenv("SECRET_KEY", "B00Kmyservice@7")
|
| 17 |
ALGORITHM: str = os.getenv("ALGORITHM", "HS256")
|
| 18 |
|
| 19 |
+
# Twilio SMS
|
| 20 |
TWILIO_ACCOUNT_SID: str = os.getenv("TWILIO_ACCOUNT_SID")
|
| 21 |
TWILIO_AUTH_TOKEN: str = os.getenv("TWILIO_AUTH_TOKEN")
|
| 22 |
TWILIO_SMS_FROM: str = os.getenv("TWILIO_SMS_FROM")
|
| 23 |
|
| 24 |
+
# SMTP Email
|
| 25 |
SMTP_HOST: str = os.getenv("SMTP_HOST")
|
| 26 |
SMTP_PORT: int = int(os.getenv("SMTP_PORT", "587"))
|
| 27 |
SMTP_USER: str = os.getenv("SMTP_USER")
|
| 28 |
SMTP_PASS: str = os.getenv("SMTP_PASS")
|
| 29 |
SMTP_FROM: str = os.getenv("SMTP_FROM")
|
| 30 |
|
| 31 |
+
# OAuth Providers
|
| 32 |
+
GOOGLE_CLIENT_ID: str = os.getenv("GOOGLE_CLIENT_ID")
|
| 33 |
+
APPLE_AUDIENCE: str = os.getenv("APPLE_AUDIENCE")
|
| 34 |
+
FACEBOOK_APP_ID: str = os.getenv("FACEBOOK_APP_ID")
|
| 35 |
+
FACEBOOK_APP_SECRET: str = os.getenv("FACEBOOK_APP_SECRET")
|
| 36 |
+
|
| 37 |
+
# Security Settings
|
| 38 |
+
MAX_LOGIN_ATTEMPTS: int = int(os.getenv("MAX_LOGIN_ATTEMPTS", "5"))
|
| 39 |
+
ACCOUNT_LOCK_DURATION: int = int(os.getenv("ACCOUNT_LOCK_DURATION", "900")) # 15 minutes
|
| 40 |
+
OTP_VALIDITY_MINUTES: int = int(os.getenv("OTP_VALIDITY_MINUTES", "5"))
|
| 41 |
+
IP_RATE_LIMIT_MAX: int = int(os.getenv("IP_RATE_LIMIT_MAX", "10"))
|
| 42 |
+
IP_RATE_LIMIT_WINDOW: int = int(os.getenv("IP_RATE_LIMIT_WINDOW", "3600")) # 1 hour
|
| 43 |
+
|
| 44 |
def __post_init__(self):
|
| 45 |
if not self.MONGO_URI or not self.DB_NAME:
|
| 46 |
raise ValueError("MongoDB URI or DB_NAME not configured.")
|
app/middleware/rate_limiter.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import Request, HTTPException
|
| 2 |
+
from starlette.middleware.base import BaseHTTPMiddleware
|
| 3 |
+
import time
|
| 4 |
+
from collections import defaultdict, deque
|
| 5 |
+
|
| 6 |
+
class RateLimitMiddleware(BaseHTTPMiddleware):
|
| 7 |
+
def __init__(self, app, calls: int = 100, period: int = 60):
|
| 8 |
+
super().__init__(app)
|
| 9 |
+
self.calls = calls
|
| 10 |
+
self.period = period
|
| 11 |
+
self.clients = defaultdict(deque)
|
| 12 |
+
|
| 13 |
+
async def dispatch(self, request: Request, call_next):
|
| 14 |
+
client_ip = request.client.host
|
| 15 |
+
now = time.time()
|
| 16 |
+
|
| 17 |
+
# Clean old requests
|
| 18 |
+
while self.clients[client_ip] and self.clients[client_ip][0] <= now - self.period:
|
| 19 |
+
self.clients[client_ip].popleft()
|
| 20 |
+
|
| 21 |
+
# Check rate limit
|
| 22 |
+
if len(self.clients[client_ip]) >= self.calls:
|
| 23 |
+
raise HTTPException(status_code=429, detail="Rate limit exceeded")
|
| 24 |
+
|
| 25 |
+
self.clients[client_ip].append(now)
|
| 26 |
+
response = await call_next(request)
|
| 27 |
+
return response
|
app/middleware/security_middleware.py
ADDED
|
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import Request, Response
|
| 2 |
+
from starlette.middleware.base import BaseHTTPMiddleware
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
import json
|
| 5 |
+
import logging
|
| 6 |
+
from typing import Dict, Any
|
| 7 |
+
from app.core.nosql_client import db
|
| 8 |
+
|
| 9 |
+
# Configure logging
|
| 10 |
+
logging.basicConfig(level=logging.INFO)
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
class SecurityMiddleware(BaseHTTPMiddleware):
|
| 14 |
+
"""
|
| 15 |
+
Enhanced security middleware for request logging, device tracking, and security monitoring
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
def __init__(self, app):
|
| 19 |
+
super().__init__(app)
|
| 20 |
+
self.security_collection = db.security_logs
|
| 21 |
+
self.device_collection = db.device_tracking
|
| 22 |
+
|
| 23 |
+
def get_client_ip(self, request: Request) -> str:
|
| 24 |
+
"""Extract client IP from request headers"""
|
| 25 |
+
# Check for forwarded headers first (for proxy/load balancer scenarios)
|
| 26 |
+
forwarded_for = request.headers.get("X-Forwarded-For")
|
| 27 |
+
if forwarded_for:
|
| 28 |
+
return forwarded_for.split(",")[0].strip()
|
| 29 |
+
|
| 30 |
+
real_ip = request.headers.get("X-Real-IP")
|
| 31 |
+
if real_ip:
|
| 32 |
+
return real_ip
|
| 33 |
+
|
| 34 |
+
# Fallback to direct client IP
|
| 35 |
+
return request.client.host if request.client else "unknown"
|
| 36 |
+
|
| 37 |
+
def extract_device_info(self, request: Request) -> Dict[str, Any]:
|
| 38 |
+
"""Extract device and browser information from request headers"""
|
| 39 |
+
user_agent = request.headers.get("User-Agent", "")
|
| 40 |
+
accept_language = request.headers.get("Accept-Language", "")
|
| 41 |
+
accept_encoding = request.headers.get("Accept-Encoding", "")
|
| 42 |
+
|
| 43 |
+
return {
|
| 44 |
+
"user_agent": user_agent,
|
| 45 |
+
"accept_language": accept_language,
|
| 46 |
+
"accept_encoding": accept_encoding,
|
| 47 |
+
"platform": self._parse_platform(user_agent),
|
| 48 |
+
"browser": self._parse_browser(user_agent)
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
def _parse_platform(self, user_agent: str) -> str:
|
| 52 |
+
"""Parse platform from user agent string"""
|
| 53 |
+
user_agent_lower = user_agent.lower()
|
| 54 |
+
|
| 55 |
+
if "windows" in user_agent_lower:
|
| 56 |
+
return "Windows"
|
| 57 |
+
elif "macintosh" in user_agent_lower or "mac os" in user_agent_lower:
|
| 58 |
+
return "macOS"
|
| 59 |
+
elif "linux" in user_agent_lower:
|
| 60 |
+
return "Linux"
|
| 61 |
+
elif "android" in user_agent_lower:
|
| 62 |
+
return "Android"
|
| 63 |
+
elif "iphone" in user_agent_lower or "ipad" in user_agent_lower:
|
| 64 |
+
return "iOS"
|
| 65 |
+
else:
|
| 66 |
+
return "Unknown"
|
| 67 |
+
|
| 68 |
+
def _parse_browser(self, user_agent: str) -> str:
|
| 69 |
+
"""Parse browser from user agent string"""
|
| 70 |
+
user_agent_lower = user_agent.lower()
|
| 71 |
+
|
| 72 |
+
if "chrome" in user_agent_lower and "edg" not in user_agent_lower:
|
| 73 |
+
return "Chrome"
|
| 74 |
+
elif "firefox" in user_agent_lower:
|
| 75 |
+
return "Firefox"
|
| 76 |
+
elif "safari" in user_agent_lower and "chrome" not in user_agent_lower:
|
| 77 |
+
return "Safari"
|
| 78 |
+
elif "edg" in user_agent_lower:
|
| 79 |
+
return "Edge"
|
| 80 |
+
elif "opera" in user_agent_lower:
|
| 81 |
+
return "Opera"
|
| 82 |
+
else:
|
| 83 |
+
return "Unknown"
|
| 84 |
+
|
| 85 |
+
def is_sensitive_endpoint(self, path: str) -> bool:
|
| 86 |
+
"""Check if the endpoint is security-sensitive and should be logged"""
|
| 87 |
+
sensitive_paths = [
|
| 88 |
+
"/auth/",
|
| 89 |
+
"/login",
|
| 90 |
+
"/register",
|
| 91 |
+
"/otp",
|
| 92 |
+
"/oauth",
|
| 93 |
+
"/profile",
|
| 94 |
+
"/account",
|
| 95 |
+
"/security"
|
| 96 |
+
]
|
| 97 |
+
|
| 98 |
+
return any(sensitive_path in path for sensitive_path in sensitive_paths)
|
| 99 |
+
|
| 100 |
+
async def log_security_event(self, request: Request, response: Response,
|
| 101 |
+
processing_time: float, client_ip: str,
|
| 102 |
+
device_info: Dict[str, Any]):
|
| 103 |
+
"""Log security-relevant events to database"""
|
| 104 |
+
try:
|
| 105 |
+
# Only log sensitive endpoints or failed requests
|
| 106 |
+
if not (self.is_sensitive_endpoint(str(request.url.path)) or response.status_code >= 400):
|
| 107 |
+
return
|
| 108 |
+
|
| 109 |
+
log_entry = {
|
| 110 |
+
"timestamp": datetime.utcnow(),
|
| 111 |
+
"method": request.method,
|
| 112 |
+
"path": str(request.url.path),
|
| 113 |
+
"query_params": dict(request.query_params),
|
| 114 |
+
"client_ip": client_ip,
|
| 115 |
+
"status_code": response.status_code,
|
| 116 |
+
"processing_time_ms": round(processing_time * 1000, 2),
|
| 117 |
+
"device_info": device_info,
|
| 118 |
+
"headers": {
|
| 119 |
+
"user_agent": request.headers.get("User-Agent", ""),
|
| 120 |
+
"referer": request.headers.get("Referer", ""),
|
| 121 |
+
"content_type": request.headers.get("Content-Type", "")
|
| 122 |
+
},
|
| 123 |
+
"is_suspicious": self._detect_suspicious_activity(request, response, client_ip)
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
# Add user ID if available from JWT token
|
| 127 |
+
auth_header = request.headers.get("Authorization")
|
| 128 |
+
if auth_header and auth_header.startswith("Bearer "):
|
| 129 |
+
try:
|
| 130 |
+
from app.utils.jwt import decode_token
|
| 131 |
+
token = auth_header.split(" ")[1]
|
| 132 |
+
payload = decode_token(token)
|
| 133 |
+
log_entry["user_id"] = payload.get("user_id")
|
| 134 |
+
except Exception:
|
| 135 |
+
pass # Token might be invalid or expired
|
| 136 |
+
|
| 137 |
+
await self.security_collection.insert_one(log_entry)
|
| 138 |
+
|
| 139 |
+
except Exception as e:
|
| 140 |
+
logger.error(f"Failed to log security event: {str(e)}")
|
| 141 |
+
|
| 142 |
+
async def track_device(self, client_ip: str, device_info: Dict[str, Any],
|
| 143 |
+
user_id: str = None):
|
| 144 |
+
"""Track device information for security monitoring"""
|
| 145 |
+
try:
|
| 146 |
+
device_fingerprint = f"{client_ip}_{device_info.get('user_agent', '')[:100]}"
|
| 147 |
+
|
| 148 |
+
device_entry = {
|
| 149 |
+
"device_fingerprint": device_fingerprint,
|
| 150 |
+
"client_ip": client_ip,
|
| 151 |
+
"device_info": device_info,
|
| 152 |
+
"first_seen": datetime.utcnow(),
|
| 153 |
+
"last_seen": datetime.utcnow(),
|
| 154 |
+
"user_id": user_id,
|
| 155 |
+
"access_count": 1,
|
| 156 |
+
"is_trusted": False
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
# Update or insert device tracking
|
| 160 |
+
await self.device_collection.update_one(
|
| 161 |
+
{"device_fingerprint": device_fingerprint},
|
| 162 |
+
{
|
| 163 |
+
"$set": {
|
| 164 |
+
"last_seen": datetime.utcnow(),
|
| 165 |
+
"device_info": device_info
|
| 166 |
+
},
|
| 167 |
+
"$inc": {"access_count": 1},
|
| 168 |
+
"$setOnInsert": {
|
| 169 |
+
"device_fingerprint": device_fingerprint,
|
| 170 |
+
"client_ip": client_ip,
|
| 171 |
+
"first_seen": datetime.utcnow(),
|
| 172 |
+
"user_id": user_id,
|
| 173 |
+
"is_trusted": False
|
| 174 |
+
}
|
| 175 |
+
},
|
| 176 |
+
upsert=True
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
except Exception as e:
|
| 180 |
+
logger.error(f"Failed to track device: {str(e)}")
|
| 181 |
+
|
| 182 |
+
def _detect_suspicious_activity(self, request: Request, response: Response,
|
| 183 |
+
client_ip: str) -> bool:
|
| 184 |
+
"""Detect potentially suspicious activity patterns"""
|
| 185 |
+
suspicious_indicators = []
|
| 186 |
+
|
| 187 |
+
# Check for multiple failed login attempts
|
| 188 |
+
if response.status_code == 401 and "login" in str(request.url.path):
|
| 189 |
+
suspicious_indicators.append("failed_login")
|
| 190 |
+
|
| 191 |
+
# Check for unusual user agent patterns
|
| 192 |
+
user_agent = request.headers.get("User-Agent", "")
|
| 193 |
+
if not user_agent or len(user_agent) < 10:
|
| 194 |
+
suspicious_indicators.append("suspicious_user_agent")
|
| 195 |
+
|
| 196 |
+
# Check for rapid requests (basic detection)
|
| 197 |
+
if hasattr(request.state, "request_count") and request.state.request_count > 10:
|
| 198 |
+
suspicious_indicators.append("rapid_requests")
|
| 199 |
+
|
| 200 |
+
# Check for access to sensitive endpoints without proper authentication
|
| 201 |
+
if (self.is_sensitive_endpoint(str(request.url.path)) and
|
| 202 |
+
response.status_code == 403 and
|
| 203 |
+
not request.headers.get("Authorization")):
|
| 204 |
+
suspicious_indicators.append("unauthorized_sensitive_access")
|
| 205 |
+
|
| 206 |
+
return len(suspicious_indicators) > 0
|
| 207 |
+
|
| 208 |
+
async def dispatch(self, request: Request, call_next):
|
| 209 |
+
"""Main middleware dispatch method"""
|
| 210 |
+
start_time = datetime.utcnow()
|
| 211 |
+
|
| 212 |
+
# Extract client information
|
| 213 |
+
client_ip = self.get_client_ip(request)
|
| 214 |
+
device_info = self.extract_device_info(request)
|
| 215 |
+
|
| 216 |
+
# Process the request
|
| 217 |
+
response = await call_next(request)
|
| 218 |
+
|
| 219 |
+
# Calculate processing time
|
| 220 |
+
end_time = datetime.utcnow()
|
| 221 |
+
processing_time = (end_time - start_time).total_seconds()
|
| 222 |
+
|
| 223 |
+
# Log security events asynchronously
|
| 224 |
+
try:
|
| 225 |
+
await self.log_security_event(request, response, processing_time,
|
| 226 |
+
client_ip, device_info)
|
| 227 |
+
await self.track_device(client_ip, device_info)
|
| 228 |
+
except Exception as e:
|
| 229 |
+
logger.error(f"Security middleware error: {str(e)}")
|
| 230 |
+
|
| 231 |
+
# Add security headers to response
|
| 232 |
+
response.headers["X-Content-Type-Options"] = "nosniff"
|
| 233 |
+
response.headers["X-Frame-Options"] = "DENY"
|
| 234 |
+
response.headers["X-XSS-Protection"] = "1; mode=block"
|
| 235 |
+
response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
|
| 236 |
+
|
| 237 |
+
return response
|
app/models/otp_model.py
CHANGED
|
@@ -11,16 +11,26 @@ class BookMyServiceOTPModel:
|
|
| 11 |
OTP_TTL = 300 # 5 minutes
|
| 12 |
RATE_LIMIT_MAX = 3
|
| 13 |
RATE_LIMIT_WINDOW = 600 # 10 minutes
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
@staticmethod
|
| 16 |
-
async def store_otp(identifier: str, phone: str, otp: str, ttl: int = OTP_TTL):
|
| 17 |
-
logger.info(f"Storing OTP for identifier: {identifier}")
|
| 18 |
|
| 19 |
try:
|
| 20 |
redis = await get_redis()
|
| 21 |
logger.debug(f"Redis connection established for OTP storage")
|
| 22 |
|
| 23 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
rate_key = f"otp_rate_limit:{identifier}"
|
| 25 |
logger.debug(f"Checking rate limit with key: {rate_key}")
|
| 26 |
|
|
@@ -34,6 +44,17 @@ class BookMyServiceOTPModel:
|
|
| 34 |
logger.warning(f"Rate limit exceeded for {identifier}: {attempts} attempts")
|
| 35 |
raise HTTPException(status_code=429, detail="Too many OTP requests. Try again later.")
|
| 36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
# Store OTP
|
| 38 |
otp_key = f"bms_otp:{identifier}"
|
| 39 |
await redis.setex(otp_key, ttl, otp)
|
|
@@ -62,14 +83,19 @@ class BookMyServiceOTPModel:
|
|
| 62 |
raise HTTPException(status_code=500, detail="SMS failed and no email fallback available.")
|
| 63 |
'''
|
| 64 |
@staticmethod
|
| 65 |
-
async def verify_otp(identifier: str, otp: str):
|
| 66 |
-
logger.info(f"Verifying OTP for identifier: {identifier}")
|
| 67 |
logger.debug(f"Provided OTP: {otp}")
|
| 68 |
|
| 69 |
try:
|
| 70 |
redis = await get_redis()
|
| 71 |
logger.debug("Redis connection established for OTP verification")
|
| 72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
key = f"bms_otp:{identifier}"
|
| 74 |
logger.debug(f"Looking up OTP with key: {key}")
|
| 75 |
|
|
@@ -81,15 +107,24 @@ class BookMyServiceOTPModel:
|
|
| 81 |
if stored == otp:
|
| 82 |
logger.info(f"OTP verification successful for {identifier}")
|
| 83 |
await redis.delete(key)
|
|
|
|
|
|
|
| 84 |
logger.debug(f"OTP deleted from Redis after successful verification")
|
| 85 |
return True
|
| 86 |
else:
|
| 87 |
logger.warning(f"OTP mismatch for {identifier}: provided='{otp}' vs stored='{stored}'")
|
|
|
|
|
|
|
| 88 |
return False
|
| 89 |
else:
|
| 90 |
logger.warning(f"No OTP found in Redis for identifier: {identifier} with key: {key}")
|
|
|
|
|
|
|
| 91 |
return False
|
| 92 |
|
|
|
|
|
|
|
|
|
|
| 93 |
except Exception as e:
|
| 94 |
logger.error(f"Error verifying OTP for {identifier}: {str(e)}", exc_info=True)
|
| 95 |
return False
|
|
@@ -101,4 +136,95 @@ class BookMyServiceOTPModel:
|
|
| 101 |
otp = await redis.get(key)
|
| 102 |
if otp:
|
| 103 |
return otp
|
| 104 |
-
raise HTTPException(status_code=404, detail="OTP not found or expired")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
OTP_TTL = 300 # 5 minutes
|
| 12 |
RATE_LIMIT_MAX = 3
|
| 13 |
RATE_LIMIT_WINDOW = 600 # 10 minutes
|
| 14 |
+
IP_RATE_LIMIT_MAX = 10 # Max 10 OTPs per IP per hour
|
| 15 |
+
IP_RATE_LIMIT_WINDOW = 3600 # 1 hour
|
| 16 |
+
FAILED_ATTEMPTS_MAX = 5 # Max 5 failed attempts before lock
|
| 17 |
+
FAILED_ATTEMPTS_WINDOW = 3600 # 1 hour
|
| 18 |
+
ACCOUNT_LOCK_DURATION = 1800 # 30 minutes
|
| 19 |
|
| 20 |
@staticmethod
|
| 21 |
+
async def store_otp(identifier: str, phone: str, otp: str, ttl: int = OTP_TTL, client_ip: str = None):
|
| 22 |
+
logger.info(f"Storing OTP for identifier: {identifier}, IP: {client_ip}")
|
| 23 |
|
| 24 |
try:
|
| 25 |
redis = await get_redis()
|
| 26 |
logger.debug(f"Redis connection established for OTP storage")
|
| 27 |
|
| 28 |
+
# Check if account is locked
|
| 29 |
+
if await BookMyServiceOTPModel.is_account_locked(identifier):
|
| 30 |
+
logger.warning(f"Account locked for identifier: {identifier}")
|
| 31 |
+
raise HTTPException(status_code=423, detail="Account temporarily locked due to too many failed attempts")
|
| 32 |
+
|
| 33 |
+
# Rate limit: max 3 OTPs per identifier per 10 minutes
|
| 34 |
rate_key = f"otp_rate_limit:{identifier}"
|
| 35 |
logger.debug(f"Checking rate limit with key: {rate_key}")
|
| 36 |
|
|
|
|
| 44 |
logger.warning(f"Rate limit exceeded for {identifier}: {attempts} attempts")
|
| 45 |
raise HTTPException(status_code=429, detail="Too many OTP requests. Try again later.")
|
| 46 |
|
| 47 |
+
# IP-based rate limiting
|
| 48 |
+
if client_ip:
|
| 49 |
+
ip_rate_key = f"otp_ip_rate_limit:{client_ip}"
|
| 50 |
+
ip_attempts = await redis.incr(ip_rate_key)
|
| 51 |
+
|
| 52 |
+
if ip_attempts == 1:
|
| 53 |
+
await redis.expire(ip_rate_key, BookMyServiceOTPModel.IP_RATE_LIMIT_WINDOW)
|
| 54 |
+
elif ip_attempts > BookMyServiceOTPModel.IP_RATE_LIMIT_MAX:
|
| 55 |
+
logger.warning(f"IP rate limit exceeded for {client_ip}: {ip_attempts} attempts")
|
| 56 |
+
raise HTTPException(status_code=429, detail="Too many OTP requests from this IP address")
|
| 57 |
+
|
| 58 |
# Store OTP
|
| 59 |
otp_key = f"bms_otp:{identifier}"
|
| 60 |
await redis.setex(otp_key, ttl, otp)
|
|
|
|
| 83 |
raise HTTPException(status_code=500, detail="SMS failed and no email fallback available.")
|
| 84 |
'''
|
| 85 |
@staticmethod
|
| 86 |
+
async def verify_otp(identifier: str, otp: str, client_ip: str = None):
|
| 87 |
+
logger.info(f"Verifying OTP for identifier: {identifier}, IP: {client_ip}")
|
| 88 |
logger.debug(f"Provided OTP: {otp}")
|
| 89 |
|
| 90 |
try:
|
| 91 |
redis = await get_redis()
|
| 92 |
logger.debug("Redis connection established for OTP verification")
|
| 93 |
|
| 94 |
+
# Check if account is locked
|
| 95 |
+
if await BookMyServiceOTPModel.is_account_locked(identifier):
|
| 96 |
+
logger.warning(f"Account locked for identifier: {identifier}")
|
| 97 |
+
raise HTTPException(status_code=423, detail="Account temporarily locked due to too many failed attempts")
|
| 98 |
+
|
| 99 |
key = f"bms_otp:{identifier}"
|
| 100 |
logger.debug(f"Looking up OTP with key: {key}")
|
| 101 |
|
|
|
|
| 107 |
if stored == otp:
|
| 108 |
logger.info(f"OTP verification successful for {identifier}")
|
| 109 |
await redis.delete(key)
|
| 110 |
+
# Clear failed attempts on successful verification
|
| 111 |
+
await BookMyServiceOTPModel.clear_failed_attempts(identifier)
|
| 112 |
logger.debug(f"OTP deleted from Redis after successful verification")
|
| 113 |
return True
|
| 114 |
else:
|
| 115 |
logger.warning(f"OTP mismatch for {identifier}: provided='{otp}' vs stored='{stored}'")
|
| 116 |
+
# Track failed attempt
|
| 117 |
+
await BookMyServiceOTPModel.track_failed_attempt(identifier, client_ip)
|
| 118 |
return False
|
| 119 |
else:
|
| 120 |
logger.warning(f"No OTP found in Redis for identifier: {identifier} with key: {key}")
|
| 121 |
+
# Track failed attempt for expired/non-existent OTP
|
| 122 |
+
await BookMyServiceOTPModel.track_failed_attempt(identifier, client_ip)
|
| 123 |
return False
|
| 124 |
|
| 125 |
+
except HTTPException as e:
|
| 126 |
+
logger.error(f"HTTP error verifying OTP for {identifier}: {e.status_code} - {e.detail}")
|
| 127 |
+
raise e
|
| 128 |
except Exception as e:
|
| 129 |
logger.error(f"Error verifying OTP for {identifier}: {str(e)}", exc_info=True)
|
| 130 |
return False
|
|
|
|
| 136 |
otp = await redis.get(key)
|
| 137 |
if otp:
|
| 138 |
return otp
|
| 139 |
+
raise HTTPException(status_code=404, detail="OTP not found or expired")
|
| 140 |
+
|
| 141 |
+
@staticmethod
|
| 142 |
+
async def track_failed_attempt(identifier: str, client_ip: str = None):
|
| 143 |
+
"""Track failed OTP verification attempts"""
|
| 144 |
+
logger.info(f"Tracking failed attempt for identifier: {identifier}, IP: {client_ip}")
|
| 145 |
+
|
| 146 |
+
try:
|
| 147 |
+
redis = await get_redis()
|
| 148 |
+
|
| 149 |
+
# Track failed attempts for identifier
|
| 150 |
+
failed_key = f"failed_otp:{identifier}"
|
| 151 |
+
attempts = await redis.incr(failed_key)
|
| 152 |
+
|
| 153 |
+
if attempts == 1:
|
| 154 |
+
await redis.expire(failed_key, BookMyServiceOTPModel.FAILED_ATTEMPTS_WINDOW)
|
| 155 |
+
|
| 156 |
+
# Lock account if too many failed attempts
|
| 157 |
+
if attempts >= BookMyServiceOTPModel.FAILED_ATTEMPTS_MAX:
|
| 158 |
+
await BookMyServiceOTPModel.lock_account(identifier)
|
| 159 |
+
logger.warning(f"Account locked for {identifier} after {attempts} failed attempts")
|
| 160 |
+
|
| 161 |
+
# Track IP-based failed attempts
|
| 162 |
+
if client_ip:
|
| 163 |
+
ip_failed_key = f"failed_otp_ip:{client_ip}"
|
| 164 |
+
ip_attempts = await redis.incr(ip_failed_key)
|
| 165 |
+
|
| 166 |
+
if ip_attempts == 1:
|
| 167 |
+
await redis.expire(ip_failed_key, BookMyServiceOTPModel.FAILED_ATTEMPTS_WINDOW)
|
| 168 |
+
|
| 169 |
+
logger.debug(f"IP {client_ip} failed attempts: {ip_attempts}")
|
| 170 |
+
|
| 171 |
+
except Exception as e:
|
| 172 |
+
logger.error(f"Error tracking failed attempt for {identifier}: {str(e)}", exc_info=True)
|
| 173 |
+
|
| 174 |
+
@staticmethod
|
| 175 |
+
async def clear_failed_attempts(identifier: str):
|
| 176 |
+
"""Clear failed attempts counter on successful verification"""
|
| 177 |
+
try:
|
| 178 |
+
redis = await get_redis()
|
| 179 |
+
failed_key = f"failed_otp:{identifier}"
|
| 180 |
+
await redis.delete(failed_key)
|
| 181 |
+
logger.debug(f"Cleared failed attempts for {identifier}")
|
| 182 |
+
except Exception as e:
|
| 183 |
+
logger.error(f"Error clearing failed attempts for {identifier}: {str(e)}", exc_info=True)
|
| 184 |
+
|
| 185 |
+
@staticmethod
|
| 186 |
+
async def lock_account(identifier: str):
|
| 187 |
+
"""Lock account temporarily"""
|
| 188 |
+
try:
|
| 189 |
+
redis = await get_redis()
|
| 190 |
+
lock_key = f"account_locked:{identifier}"
|
| 191 |
+
await redis.setex(lock_key, BookMyServiceOTPModel.ACCOUNT_LOCK_DURATION, "locked")
|
| 192 |
+
logger.info(f"Account locked for {identifier} for {BookMyServiceOTPModel.ACCOUNT_LOCK_DURATION} seconds")
|
| 193 |
+
except Exception as e:
|
| 194 |
+
logger.error(f"Error locking account for {identifier}: {str(e)}", exc_info=True)
|
| 195 |
+
|
| 196 |
+
@staticmethod
|
| 197 |
+
async def is_account_locked(identifier: str) -> bool:
|
| 198 |
+
"""Check if account is currently locked"""
|
| 199 |
+
try:
|
| 200 |
+
redis = await get_redis()
|
| 201 |
+
lock_key = f"account_locked:{identifier}"
|
| 202 |
+
locked = await redis.get(lock_key)
|
| 203 |
+
return locked is not None
|
| 204 |
+
except Exception as e:
|
| 205 |
+
logger.error(f"Error checking account lock for {identifier}: {str(e)}", exc_info=True)
|
| 206 |
+
return False
|
| 207 |
+
|
| 208 |
+
@staticmethod
|
| 209 |
+
async def get_rate_limit_count(rate_key: str) -> int:
|
| 210 |
+
"""Get current rate limit count for a key"""
|
| 211 |
+
try:
|
| 212 |
+
redis = await get_redis()
|
| 213 |
+
count = await redis.get(rate_key)
|
| 214 |
+
return int(count) if count else 0
|
| 215 |
+
except Exception as e:
|
| 216 |
+
logger.error(f"Error getting rate limit count for {rate_key}: {str(e)}", exc_info=True)
|
| 217 |
+
return 0
|
| 218 |
+
|
| 219 |
+
@staticmethod
|
| 220 |
+
async def increment_rate_limit(rate_key: str, window: int) -> int:
|
| 221 |
+
"""Increment rate limit counter with expiry"""
|
| 222 |
+
try:
|
| 223 |
+
redis = await get_redis()
|
| 224 |
+
count = await redis.incr(rate_key)
|
| 225 |
+
if count == 1:
|
| 226 |
+
await redis.expire(rate_key, window)
|
| 227 |
+
return count
|
| 228 |
+
except Exception as e:
|
| 229 |
+
logger.error(f"Error incrementing rate limit for {rate_key}: {str(e)}", exc_info=True)
|
| 230 |
+
return 0
|
app/models/social_account_model.py
ADDED
|
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import HTTPException
|
| 2 |
+
from app.core.nosql_client import db
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
from typing import Optional, List, Dict, Any
|
| 5 |
+
import logging
|
| 6 |
+
|
| 7 |
+
logger = logging.getLogger("social_account_model")
|
| 8 |
+
|
| 9 |
+
class SocialAccountModel:
|
| 10 |
+
"""Model for managing social login accounts and linking"""
|
| 11 |
+
|
| 12 |
+
collection = db["social_accounts"]
|
| 13 |
+
|
| 14 |
+
@staticmethod
|
| 15 |
+
async def create_social_account(user_id: str, provider: str, provider_user_id: str, user_info: Dict[str, Any]) -> str:
|
| 16 |
+
"""Create a new social account record"""
|
| 17 |
+
try:
|
| 18 |
+
social_account = {
|
| 19 |
+
"user_id": user_id,
|
| 20 |
+
"provider": provider,
|
| 21 |
+
"provider_user_id": provider_user_id,
|
| 22 |
+
"email": user_info.get("email"),
|
| 23 |
+
"name": user_info.get("name"),
|
| 24 |
+
"picture": user_info.get("picture"),
|
| 25 |
+
"profile_data": user_info,
|
| 26 |
+
"created_at": datetime.utcnow(),
|
| 27 |
+
"updated_at": datetime.utcnow(),
|
| 28 |
+
"is_active": True,
|
| 29 |
+
"last_login": datetime.utcnow()
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
result = await SocialAccountModel.collection.insert_one(social_account)
|
| 33 |
+
logger.info(f"Created social account for user {user_id} with provider {provider}")
|
| 34 |
+
return str(result.inserted_id)
|
| 35 |
+
|
| 36 |
+
except Exception as e:
|
| 37 |
+
logger.error(f"Error creating social account: {str(e)}", exc_info=True)
|
| 38 |
+
raise HTTPException(status_code=500, detail="Failed to create social account")
|
| 39 |
+
|
| 40 |
+
@staticmethod
|
| 41 |
+
async def find_by_provider_and_user_id(provider: str, provider_user_id: str) -> Optional[Dict[str, Any]]:
|
| 42 |
+
"""Find social account by provider and provider user ID"""
|
| 43 |
+
try:
|
| 44 |
+
account = await SocialAccountModel.collection.find_one({
|
| 45 |
+
"provider": provider,
|
| 46 |
+
"provider_user_id": provider_user_id,
|
| 47 |
+
"is_active": True
|
| 48 |
+
})
|
| 49 |
+
return account
|
| 50 |
+
except Exception as e:
|
| 51 |
+
logger.error(f"Error finding social account: {str(e)}", exc_info=True)
|
| 52 |
+
return None
|
| 53 |
+
|
| 54 |
+
@staticmethod
|
| 55 |
+
async def find_by_user_id(user_id: str) -> List[Dict[str, Any]]:
|
| 56 |
+
"""Find all social accounts for a user"""
|
| 57 |
+
try:
|
| 58 |
+
cursor = SocialAccountModel.collection.find({
|
| 59 |
+
"user_id": user_id,
|
| 60 |
+
"is_active": True
|
| 61 |
+
})
|
| 62 |
+
accounts = await cursor.to_list(length=None)
|
| 63 |
+
return accounts
|
| 64 |
+
except Exception as e:
|
| 65 |
+
logger.error(f"Error finding social accounts for user {user_id}: {str(e)}", exc_info=True)
|
| 66 |
+
return []
|
| 67 |
+
|
| 68 |
+
@staticmethod
|
| 69 |
+
async def update_social_account(provider: str, provider_user_id: str, user_info: Dict[str, Any]) -> bool:
|
| 70 |
+
"""Update social account with latest user info"""
|
| 71 |
+
try:
|
| 72 |
+
update_data = {
|
| 73 |
+
"email": user_info.get("email"),
|
| 74 |
+
"name": user_info.get("name"),
|
| 75 |
+
"picture": user_info.get("picture"),
|
| 76 |
+
"profile_data": user_info,
|
| 77 |
+
"updated_at": datetime.utcnow(),
|
| 78 |
+
"last_login": datetime.utcnow()
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
result = await SocialAccountModel.collection.update_one(
|
| 82 |
+
{
|
| 83 |
+
"provider": provider,
|
| 84 |
+
"provider_user_id": provider_user_id,
|
| 85 |
+
"is_active": True
|
| 86 |
+
},
|
| 87 |
+
{"$set": update_data}
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
return result.modified_count > 0
|
| 91 |
+
|
| 92 |
+
except Exception as e:
|
| 93 |
+
logger.error(f"Error updating social account: {str(e)}", exc_info=True)
|
| 94 |
+
return False
|
| 95 |
+
|
| 96 |
+
@staticmethod
|
| 97 |
+
async def link_social_account(user_id: str, provider: str, provider_user_id: str, user_info: Dict[str, Any]) -> bool:
|
| 98 |
+
"""Link a social account to an existing user"""
|
| 99 |
+
try:
|
| 100 |
+
# Check if this social account is already linked to another user
|
| 101 |
+
existing_account = await SocialAccountModel.find_by_provider_and_user_id(provider, provider_user_id)
|
| 102 |
+
|
| 103 |
+
if existing_account and existing_account["user_id"] != user_id:
|
| 104 |
+
logger.warning(f"Social account {provider}:{provider_user_id} already linked to user {existing_account['user_id']}")
|
| 105 |
+
raise HTTPException(
|
| 106 |
+
status_code=409,
|
| 107 |
+
detail=f"This {provider} account is already linked to another user"
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
if existing_account and existing_account["user_id"] == user_id:
|
| 111 |
+
# Update existing account
|
| 112 |
+
await SocialAccountModel.update_social_account(provider, provider_user_id, user_info)
|
| 113 |
+
return True
|
| 114 |
+
|
| 115 |
+
# Create new social account link
|
| 116 |
+
await SocialAccountModel.create_social_account(user_id, provider, provider_user_id, user_info)
|
| 117 |
+
return True
|
| 118 |
+
|
| 119 |
+
except HTTPException:
|
| 120 |
+
raise
|
| 121 |
+
except Exception as e:
|
| 122 |
+
logger.error(f"Error linking social account: {str(e)}", exc_info=True)
|
| 123 |
+
raise HTTPException(status_code=500, detail="Failed to link social account")
|
| 124 |
+
|
| 125 |
+
@staticmethod
|
| 126 |
+
async def unlink_social_account(user_id: str, provider: str) -> bool:
|
| 127 |
+
"""Unlink a social account from a user"""
|
| 128 |
+
try:
|
| 129 |
+
result = await SocialAccountModel.collection.update_one(
|
| 130 |
+
{
|
| 131 |
+
"user_id": user_id,
|
| 132 |
+
"provider": provider,
|
| 133 |
+
"is_active": True
|
| 134 |
+
},
|
| 135 |
+
{
|
| 136 |
+
"$set": {
|
| 137 |
+
"is_active": False,
|
| 138 |
+
"updated_at": datetime.utcnow()
|
| 139 |
+
}
|
| 140 |
+
}
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
if result.modified_count > 0:
|
| 144 |
+
logger.info(f"Unlinked {provider} account for user {user_id}")
|
| 145 |
+
return True
|
| 146 |
+
else:
|
| 147 |
+
logger.warning(f"No active {provider} account found for user {user_id}")
|
| 148 |
+
return False
|
| 149 |
+
|
| 150 |
+
except Exception as e:
|
| 151 |
+
logger.error(f"Error unlinking social account: {str(e)}", exc_info=True)
|
| 152 |
+
return False
|
| 153 |
+
|
| 154 |
+
@staticmethod
|
| 155 |
+
async def get_profile_picture(user_id: str, preferred_provider: str = None) -> Optional[str]:
|
| 156 |
+
"""Get user's profile picture from social accounts"""
|
| 157 |
+
try:
|
| 158 |
+
query = {"user_id": user_id, "is_active": True}
|
| 159 |
+
|
| 160 |
+
# If preferred provider specified, try that first
|
| 161 |
+
if preferred_provider:
|
| 162 |
+
account = await SocialAccountModel.collection.find_one({
|
| 163 |
+
**query,
|
| 164 |
+
"provider": preferred_provider,
|
| 165 |
+
"picture": {"$exists": True, "$ne": None}
|
| 166 |
+
})
|
| 167 |
+
if account and account.get("picture"):
|
| 168 |
+
return account["picture"]
|
| 169 |
+
|
| 170 |
+
# Otherwise, get any account with a profile picture
|
| 171 |
+
account = await SocialAccountModel.collection.find_one({
|
| 172 |
+
**query,
|
| 173 |
+
"picture": {"$exists": True, "$ne": None}
|
| 174 |
+
})
|
| 175 |
+
|
| 176 |
+
return account.get("picture") if account else None
|
| 177 |
+
|
| 178 |
+
except Exception as e:
|
| 179 |
+
logger.error(f"Error getting profile picture for user {user_id}: {str(e)}", exc_info=True)
|
| 180 |
+
return None
|
| 181 |
+
|
| 182 |
+
@staticmethod
|
| 183 |
+
async def get_social_account_summary(user_id: str) -> Dict[str, Any]:
|
| 184 |
+
"""Get summary of all linked social accounts for a user"""
|
| 185 |
+
try:
|
| 186 |
+
accounts = await SocialAccountModel.find_by_user_id(user_id)
|
| 187 |
+
|
| 188 |
+
summary = {
|
| 189 |
+
"linked_accounts": [],
|
| 190 |
+
"total_accounts": len(accounts),
|
| 191 |
+
"profile_picture": None
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
for account in accounts:
|
| 195 |
+
summary["linked_accounts"].append({
|
| 196 |
+
"provider": account["provider"],
|
| 197 |
+
"email": account.get("email"),
|
| 198 |
+
"name": account.get("name"),
|
| 199 |
+
"linked_at": account["created_at"],
|
| 200 |
+
"last_login": account.get("last_login")
|
| 201 |
+
})
|
| 202 |
+
|
| 203 |
+
# Set profile picture if available
|
| 204 |
+
if not summary["profile_picture"] and account.get("picture"):
|
| 205 |
+
summary["profile_picture"] = account["picture"]
|
| 206 |
+
|
| 207 |
+
return summary
|
| 208 |
+
|
| 209 |
+
except Exception as e:
|
| 210 |
+
logger.error(f"Error getting social account summary for user {user_id}: {str(e)}", exc_info=True)
|
| 211 |
+
return {"linked_accounts": [], "total_accounts": 0, "profile_picture": None}
|
| 212 |
+
|
| 213 |
+
@staticmethod
|
| 214 |
+
async def merge_social_accounts(primary_user_id: str, secondary_user_id: str) -> bool:
|
| 215 |
+
"""Merge social accounts from secondary user to primary user"""
|
| 216 |
+
try:
|
| 217 |
+
# Get all social accounts from secondary user
|
| 218 |
+
secondary_accounts = await SocialAccountModel.find_by_user_id(secondary_user_id)
|
| 219 |
+
|
| 220 |
+
for account in secondary_accounts:
|
| 221 |
+
# Check if primary user already has this provider linked
|
| 222 |
+
existing = await SocialAccountModel.collection.find_one({
|
| 223 |
+
"user_id": primary_user_id,
|
| 224 |
+
"provider": account["provider"],
|
| 225 |
+
"is_active": True
|
| 226 |
+
})
|
| 227 |
+
|
| 228 |
+
if not existing:
|
| 229 |
+
# Transfer the account to primary user
|
| 230 |
+
await SocialAccountModel.collection.update_one(
|
| 231 |
+
{"_id": account["_id"]},
|
| 232 |
+
{
|
| 233 |
+
"$set": {
|
| 234 |
+
"user_id": primary_user_id,
|
| 235 |
+
"updated_at": datetime.utcnow()
|
| 236 |
+
}
|
| 237 |
+
}
|
| 238 |
+
)
|
| 239 |
+
logger.info(f"Transferred {account['provider']} account from user {secondary_user_id} to {primary_user_id}")
|
| 240 |
+
else:
|
| 241 |
+
# Deactivate the secondary account
|
| 242 |
+
await SocialAccountModel.collection.update_one(
|
| 243 |
+
{"_id": account["_id"]},
|
| 244 |
+
{
|
| 245 |
+
"$set": {
|
| 246 |
+
"is_active": False,
|
| 247 |
+
"updated_at": datetime.utcnow()
|
| 248 |
+
}
|
| 249 |
+
}
|
| 250 |
+
)
|
| 251 |
+
logger.info(f"Deactivated duplicate {account['provider']} account for user {secondary_user_id}")
|
| 252 |
+
|
| 253 |
+
return True
|
| 254 |
+
|
| 255 |
+
except Exception as e:
|
| 256 |
+
logger.error(f"Error merging social accounts: {str(e)}", exc_info=True)
|
| 257 |
+
return False
|
app/models/social_security_model.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datetime import datetime, timedelta
|
| 2 |
+
import logging
|
| 3 |
+
from app.core.cache_client import get_redis
|
| 4 |
+
from fastapi import HTTPException
|
| 5 |
+
|
| 6 |
+
logger = logging.getLogger(__name__)
|
| 7 |
+
|
| 8 |
+
class SocialSecurityModel:
|
| 9 |
+
"""Model for handling social login security features"""
|
| 10 |
+
|
| 11 |
+
# Rate limiting constants
|
| 12 |
+
OAUTH_RATE_LIMIT_MAX = 5 # Max OAuth attempts per IP per hour
|
| 13 |
+
OAUTH_RATE_LIMIT_WINDOW = 3600 # 1 hour in seconds
|
| 14 |
+
|
| 15 |
+
# Failed attempt tracking
|
| 16 |
+
OAUTH_FAILED_ATTEMPTS_MAX = 3 # Max failed OAuth attempts per IP
|
| 17 |
+
OAUTH_FAILED_ATTEMPTS_WINDOW = 1800 # 30 minutes
|
| 18 |
+
OAUTH_IP_LOCK_DURATION = 3600 # 1 hour lock for IP
|
| 19 |
+
|
| 20 |
+
@staticmethod
|
| 21 |
+
async def check_oauth_rate_limit(client_ip: str, provider: str) -> bool:
|
| 22 |
+
"""Check if OAuth rate limit is exceeded for IP and provider"""
|
| 23 |
+
if not client_ip:
|
| 24 |
+
return True # Allow if no IP provided
|
| 25 |
+
|
| 26 |
+
try:
|
| 27 |
+
redis = await get_redis()
|
| 28 |
+
rate_key = f"oauth_rate:{client_ip}:{provider}"
|
| 29 |
+
|
| 30 |
+
current_count = await redis.get(rate_key)
|
| 31 |
+
if current_count and int(current_count) >= SocialSecurityModel.OAUTH_RATE_LIMIT_MAX:
|
| 32 |
+
logger.warning(f"OAuth rate limit exceeded for IP {client_ip} and provider {provider}")
|
| 33 |
+
return False
|
| 34 |
+
|
| 35 |
+
return True
|
| 36 |
+
|
| 37 |
+
except Exception as e:
|
| 38 |
+
logger.error(f"Error checking OAuth rate limit: {str(e)}", exc_info=True)
|
| 39 |
+
return True # Allow on error to avoid blocking legitimate users
|
| 40 |
+
|
| 41 |
+
@staticmethod
|
| 42 |
+
async def increment_oauth_rate_limit(client_ip: str, provider: str):
|
| 43 |
+
"""Increment OAuth rate limit counter"""
|
| 44 |
+
if not client_ip:
|
| 45 |
+
return
|
| 46 |
+
|
| 47 |
+
try:
|
| 48 |
+
redis = await get_redis()
|
| 49 |
+
rate_key = f"oauth_rate:{client_ip}:{provider}"
|
| 50 |
+
|
| 51 |
+
count = await redis.incr(rate_key)
|
| 52 |
+
if count == 1:
|
| 53 |
+
await redis.expire(rate_key, SocialSecurityModel.OAUTH_RATE_LIMIT_WINDOW)
|
| 54 |
+
|
| 55 |
+
logger.debug(f"OAuth rate limit count for {client_ip}:{provider} = {count}")
|
| 56 |
+
|
| 57 |
+
except Exception as e:
|
| 58 |
+
logger.error(f"Error incrementing OAuth rate limit: {str(e)}", exc_info=True)
|
| 59 |
+
|
| 60 |
+
@staticmethod
|
| 61 |
+
async def track_oauth_failed_attempt(client_ip: str, provider: str):
|
| 62 |
+
"""Track failed OAuth verification attempts"""
|
| 63 |
+
if not client_ip:
|
| 64 |
+
return
|
| 65 |
+
|
| 66 |
+
try:
|
| 67 |
+
redis = await get_redis()
|
| 68 |
+
failed_key = f"oauth_failed:{client_ip}:{provider}"
|
| 69 |
+
|
| 70 |
+
attempts = await redis.incr(failed_key)
|
| 71 |
+
if attempts == 1:
|
| 72 |
+
await redis.expire(failed_key, SocialSecurityModel.OAUTH_FAILED_ATTEMPTS_WINDOW)
|
| 73 |
+
|
| 74 |
+
# Lock IP if too many failed attempts
|
| 75 |
+
if attempts >= SocialSecurityModel.OAUTH_FAILED_ATTEMPTS_MAX:
|
| 76 |
+
await SocialSecurityModel.lock_oauth_ip(client_ip, provider)
|
| 77 |
+
logger.warning(f"IP {client_ip} locked for provider {provider} after {attempts} failed attempts")
|
| 78 |
+
|
| 79 |
+
logger.debug(f"OAuth failed attempts for {client_ip}:{provider} = {attempts}")
|
| 80 |
+
|
| 81 |
+
except Exception as e:
|
| 82 |
+
logger.error(f"Error tracking OAuth failed attempt: {str(e)}", exc_info=True)
|
| 83 |
+
|
| 84 |
+
@staticmethod
|
| 85 |
+
async def lock_oauth_ip(client_ip: str, provider: str):
|
| 86 |
+
"""Lock IP for OAuth attempts on specific provider"""
|
| 87 |
+
try:
|
| 88 |
+
redis = await get_redis()
|
| 89 |
+
lock_key = f"oauth_ip_locked:{client_ip}:{provider}"
|
| 90 |
+
await redis.setex(lock_key, SocialSecurityModel.OAUTH_IP_LOCK_DURATION, "locked")
|
| 91 |
+
logger.info(f"IP {client_ip} locked for OAuth provider {provider}")
|
| 92 |
+
except Exception as e:
|
| 93 |
+
logger.error(f"Error locking OAuth IP: {str(e)}", exc_info=True)
|
| 94 |
+
|
| 95 |
+
@staticmethod
|
| 96 |
+
async def is_oauth_ip_locked(client_ip: str, provider: str) -> bool:
|
| 97 |
+
"""Check if IP is locked for OAuth attempts on specific provider"""
|
| 98 |
+
if not client_ip:
|
| 99 |
+
return False
|
| 100 |
+
|
| 101 |
+
try:
|
| 102 |
+
redis = await get_redis()
|
| 103 |
+
lock_key = f"oauth_ip_locked:{client_ip}:{provider}"
|
| 104 |
+
locked = await redis.get(lock_key)
|
| 105 |
+
return locked is not None
|
| 106 |
+
except Exception as e:
|
| 107 |
+
logger.error(f"Error checking OAuth IP lock: {str(e)}", exc_info=True)
|
| 108 |
+
return False
|
| 109 |
+
|
| 110 |
+
@staticmethod
|
| 111 |
+
async def clear_oauth_failed_attempts(client_ip: str, provider: str):
|
| 112 |
+
"""Clear failed OAuth attempts on successful verification"""
|
| 113 |
+
if not client_ip:
|
| 114 |
+
return
|
| 115 |
+
|
| 116 |
+
try:
|
| 117 |
+
redis = await get_redis()
|
| 118 |
+
failed_key = f"oauth_failed:{client_ip}:{provider}"
|
| 119 |
+
await redis.delete(failed_key)
|
| 120 |
+
logger.debug(f"Cleared OAuth failed attempts for {client_ip}:{provider}")
|
| 121 |
+
except Exception as e:
|
| 122 |
+
logger.error(f"Error clearing OAuth failed attempts: {str(e)}", exc_info=True)
|
| 123 |
+
|
| 124 |
+
@staticmethod
|
| 125 |
+
async def validate_oauth_token_format(token: str, provider: str) -> bool:
|
| 126 |
+
"""Basic validation of OAuth token format"""
|
| 127 |
+
if not token or not isinstance(token, str):
|
| 128 |
+
return False
|
| 129 |
+
|
| 130 |
+
# Basic length and format checks
|
| 131 |
+
if provider == "google":
|
| 132 |
+
# Google ID tokens are typically JWT format
|
| 133 |
+
return len(token) > 100 and token.count('.') == 2
|
| 134 |
+
elif provider == "apple":
|
| 135 |
+
# Apple ID tokens are also JWT format
|
| 136 |
+
return len(token) > 100 and token.count('.') == 2
|
| 137 |
+
elif provider == "facebook":
|
| 138 |
+
# Facebook access tokens are typically shorter
|
| 139 |
+
return len(token) > 20 and len(token) < 500
|
| 140 |
+
|
| 141 |
+
return True # Allow unknown providers
|
| 142 |
+
|
| 143 |
+
@staticmethod
|
| 144 |
+
async def log_oauth_attempt(client_ip: str, provider: str, success: bool, user_id: str = None):
|
| 145 |
+
"""Log OAuth authentication attempts for security monitoring"""
|
| 146 |
+
try:
|
| 147 |
+
redis = await get_redis()
|
| 148 |
+
log_key = f"oauth_log:{datetime.utcnow().strftime('%Y-%m-%d')}"
|
| 149 |
+
|
| 150 |
+
log_entry = {
|
| 151 |
+
"timestamp": datetime.utcnow().isoformat(),
|
| 152 |
+
"ip": client_ip,
|
| 153 |
+
"provider": provider,
|
| 154 |
+
"success": success,
|
| 155 |
+
"user_id": user_id
|
| 156 |
+
}
|
| 157 |
+
|
| 158 |
+
# Store as JSON string in Redis list
|
| 159 |
+
import json
|
| 160 |
+
await redis.lpush(log_key, json.dumps(log_entry))
|
| 161 |
+
|
| 162 |
+
# Keep only last 1000 entries per day
|
| 163 |
+
await redis.ltrim(log_key, 0, 999)
|
| 164 |
+
|
| 165 |
+
# Set expiry for 30 days
|
| 166 |
+
await redis.expire(log_key, 30 * 24 * 3600)
|
| 167 |
+
|
| 168 |
+
logger.info(f"OAuth attempt logged: {provider} from {client_ip} - {'success' if success else 'failed'}")
|
| 169 |
+
|
| 170 |
+
except Exception as e:
|
| 171 |
+
logger.error(f"Error logging OAuth attempt: {str(e)}", exc_info=True)
|
app/routers/account_router.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, Depends, HTTPException, Request, Query
|
| 2 |
+
from fastapi.security import HTTPBearer
|
| 3 |
+
from typing import List, Optional
|
| 4 |
+
from datetime import datetime, timedelta
|
| 5 |
+
import logging
|
| 6 |
+
|
| 7 |
+
from app.schemas.user_schema import (
|
| 8 |
+
LinkSocialAccountRequest, UnlinkSocialAccountRequest,
|
| 9 |
+
SocialAccountSummary, LoginHistoryResponse, SecuritySettingsResponse,
|
| 10 |
+
TokenResponse
|
| 11 |
+
)
|
| 12 |
+
from app.services.account_service import AccountService
|
| 13 |
+
from app.utils.jwt import decode_token
|
| 14 |
+
|
| 15 |
+
# Configure logging
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
router = APIRouter()
|
| 19 |
+
security = HTTPBearer()
|
| 20 |
+
|
| 21 |
+
def get_current_user(token: str = Depends(security)):
|
| 22 |
+
"""Extract user ID from JWT token"""
|
| 23 |
+
try:
|
| 24 |
+
payload = decode_token(token.credentials)
|
| 25 |
+
user_id = payload.get("user_id")
|
| 26 |
+
if not user_id:
|
| 27 |
+
raise HTTPException(status_code=401, detail="Invalid token")
|
| 28 |
+
return user_id
|
| 29 |
+
except Exception as e:
|
| 30 |
+
logger.error(f"Token validation error: {str(e)}")
|
| 31 |
+
raise HTTPException(status_code=401, detail="Invalid or expired token")
|
| 32 |
+
|
| 33 |
+
def get_client_ip(request: Request) -> str:
|
| 34 |
+
"""Extract client IP from request"""
|
| 35 |
+
forwarded_for = request.headers.get("X-Forwarded-For")
|
| 36 |
+
if forwarded_for:
|
| 37 |
+
return forwarded_for.split(",")[0].strip()
|
| 38 |
+
|
| 39 |
+
real_ip = request.headers.get("X-Real-IP")
|
| 40 |
+
if real_ip:
|
| 41 |
+
return real_ip
|
| 42 |
+
|
| 43 |
+
return request.client.host if request.client else "unknown"
|
| 44 |
+
|
| 45 |
+
@router.get("/social-accounts", response_model=SocialAccountSummary)
|
| 46 |
+
async def get_social_accounts(user_id: str = Depends(get_current_user)):
|
| 47 |
+
"""Get all linked social accounts for the current user"""
|
| 48 |
+
try:
|
| 49 |
+
account_service = AccountService()
|
| 50 |
+
summary = await account_service.get_social_account_summary(user_id)
|
| 51 |
+
return summary
|
| 52 |
+
except Exception as e:
|
| 53 |
+
logger.error(f"Error fetching social accounts for user {user_id}: {str(e)}")
|
| 54 |
+
raise HTTPException(status_code=500, detail="Failed to fetch social accounts")
|
| 55 |
+
|
| 56 |
+
@router.post("/link-social-account", response_model=dict)
|
| 57 |
+
async def link_social_account(
|
| 58 |
+
request: LinkSocialAccountRequest,
|
| 59 |
+
req: Request,
|
| 60 |
+
user_id: str = Depends(get_current_user)
|
| 61 |
+
):
|
| 62 |
+
"""Link a new social account to the current user"""
|
| 63 |
+
try:
|
| 64 |
+
client_ip = get_client_ip(req)
|
| 65 |
+
account_service = AccountService()
|
| 66 |
+
|
| 67 |
+
result = await account_service.link_social_account(
|
| 68 |
+
user_id=user_id,
|
| 69 |
+
provider=request.provider,
|
| 70 |
+
token=request.token,
|
| 71 |
+
client_ip=client_ip
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
return {"message": f"Successfully linked {request.provider} account", "result": result}
|
| 75 |
+
except ValueError as e:
|
| 76 |
+
logger.warning(f"Invalid link request for user {user_id}: {str(e)}")
|
| 77 |
+
raise HTTPException(status_code=400, detail=str(e))
|
| 78 |
+
except Exception as e:
|
| 79 |
+
logger.error(f"Error linking social account for user {user_id}: {str(e)}")
|
| 80 |
+
raise HTTPException(status_code=500, detail="Failed to link social account")
|
| 81 |
+
|
| 82 |
+
@router.delete("/unlink-social-account", response_model=dict)
|
| 83 |
+
async def unlink_social_account(
|
| 84 |
+
request: UnlinkSocialAccountRequest,
|
| 85 |
+
user_id: str = Depends(get_current_user)
|
| 86 |
+
):
|
| 87 |
+
"""Unlink a social account from the current user"""
|
| 88 |
+
try:
|
| 89 |
+
account_service = AccountService()
|
| 90 |
+
|
| 91 |
+
result = await account_service.unlink_social_account(
|
| 92 |
+
user_id=user_id,
|
| 93 |
+
provider=request.provider
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
return {"message": f"Successfully unlinked {request.provider} account", "result": result}
|
| 97 |
+
except ValueError as e:
|
| 98 |
+
logger.warning(f"Invalid unlink request for user {user_id}: {str(e)}")
|
| 99 |
+
raise HTTPException(status_code=400, detail=str(e))
|
| 100 |
+
except Exception as e:
|
| 101 |
+
logger.error(f"Error unlinking social account for user {user_id}: {str(e)}")
|
| 102 |
+
raise HTTPException(status_code=500, detail="Failed to unlink social account")
|
| 103 |
+
|
| 104 |
+
@router.get("/login-history", response_model=LoginHistoryResponse)
|
| 105 |
+
async def get_login_history(
|
| 106 |
+
page: int = Query(1, ge=1, description="Page number"),
|
| 107 |
+
per_page: int = Query(10, ge=1, le=50, description="Items per page"),
|
| 108 |
+
days: int = Query(30, ge=1, le=365, description="Number of days to look back"),
|
| 109 |
+
user_id: str = Depends(get_current_user)
|
| 110 |
+
):
|
| 111 |
+
"""Get login history for the current user"""
|
| 112 |
+
try:
|
| 113 |
+
account_service = AccountService()
|
| 114 |
+
|
| 115 |
+
history = await account_service.get_login_history(
|
| 116 |
+
user_id=user_id,
|
| 117 |
+
page=page,
|
| 118 |
+
per_page=per_page,
|
| 119 |
+
days=days
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
return history
|
| 123 |
+
except Exception as e:
|
| 124 |
+
logger.error(f"Error fetching login history for user {user_id}: {str(e)}")
|
| 125 |
+
raise HTTPException(status_code=500, detail="Failed to fetch login history")
|
| 126 |
+
|
| 127 |
+
@router.get("/security-settings", response_model=SecuritySettingsResponse)
|
| 128 |
+
async def get_security_settings(user_id: str = Depends(get_current_user)):
|
| 129 |
+
"""Get security settings and status for the current user"""
|
| 130 |
+
try:
|
| 131 |
+
account_service = AccountService()
|
| 132 |
+
|
| 133 |
+
settings = await account_service.get_security_settings(user_id)
|
| 134 |
+
|
| 135 |
+
return settings
|
| 136 |
+
except Exception as e:
|
| 137 |
+
logger.error(f"Error fetching security settings for user {user_id}: {str(e)}")
|
| 138 |
+
raise HTTPException(status_code=500, detail="Failed to fetch security settings")
|
| 139 |
+
|
| 140 |
+
@router.post("/merge-accounts", response_model=dict)
|
| 141 |
+
async def merge_social_accounts(
|
| 142 |
+
target_user_id: str,
|
| 143 |
+
req: Request,
|
| 144 |
+
user_id: str = Depends(get_current_user)
|
| 145 |
+
):
|
| 146 |
+
"""Merge social accounts from another user (admin function or user-initiated)"""
|
| 147 |
+
try:
|
| 148 |
+
# For security, only allow users to merge their own accounts or implement admin check
|
| 149 |
+
if user_id != target_user_id:
|
| 150 |
+
# In a real implementation, you'd check if the current user is an admin
|
| 151 |
+
# or if they have proper authorization to merge accounts
|
| 152 |
+
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
| 153 |
+
|
| 154 |
+
client_ip = get_client_ip(req)
|
| 155 |
+
account_service = AccountService()
|
| 156 |
+
|
| 157 |
+
result = await account_service.merge_social_accounts(
|
| 158 |
+
primary_user_id=user_id,
|
| 159 |
+
secondary_user_id=target_user_id,
|
| 160 |
+
client_ip=client_ip
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
return {"message": "Successfully merged social accounts", "result": result}
|
| 164 |
+
except ValueError as e:
|
| 165 |
+
logger.warning(f"Invalid merge request for user {user_id}: {str(e)}")
|
| 166 |
+
raise HTTPException(status_code=400, detail=str(e))
|
| 167 |
+
except Exception as e:
|
| 168 |
+
logger.error(f"Error merging social accounts for user {user_id}: {str(e)}")
|
| 169 |
+
raise HTTPException(status_code=500, detail="Failed to merge social accounts")
|
| 170 |
+
|
| 171 |
+
@router.delete("/revoke-all-sessions", response_model=dict)
|
| 172 |
+
async def revoke_all_sessions(
|
| 173 |
+
req: Request,
|
| 174 |
+
user_id: str = Depends(get_current_user)
|
| 175 |
+
):
|
| 176 |
+
"""Revoke all active sessions for security purposes"""
|
| 177 |
+
try:
|
| 178 |
+
client_ip = get_client_ip(req)
|
| 179 |
+
account_service = AccountService()
|
| 180 |
+
|
| 181 |
+
result = await account_service.revoke_all_sessions(user_id, client_ip)
|
| 182 |
+
|
| 183 |
+
return {"message": "All sessions have been revoked", "result": result}
|
| 184 |
+
except Exception as e:
|
| 185 |
+
logger.error(f"Error revoking sessions for user {user_id}: {str(e)}")
|
| 186 |
+
raise HTTPException(status_code=500, detail="Failed to revoke sessions")
|
| 187 |
+
|
| 188 |
+
@router.get("/trusted-devices", response_model=dict)
|
| 189 |
+
async def get_trusted_devices(user_id: str = Depends(get_current_user)):
|
| 190 |
+
"""Get list of trusted devices for the current user"""
|
| 191 |
+
try:
|
| 192 |
+
account_service = AccountService()
|
| 193 |
+
|
| 194 |
+
devices = await account_service.get_trusted_devices(user_id)
|
| 195 |
+
|
| 196 |
+
return {"devices": devices}
|
| 197 |
+
except Exception as e:
|
| 198 |
+
logger.error(f"Error fetching trusted devices for user {user_id}: {str(e)}")
|
| 199 |
+
raise HTTPException(status_code=500, detail="Failed to fetch trusted devices")
|
| 200 |
+
|
| 201 |
+
@router.delete("/trusted-devices/{device_id}", response_model=dict)
|
| 202 |
+
async def remove_trusted_device(
|
| 203 |
+
device_id: str,
|
| 204 |
+
user_id: str = Depends(get_current_user)
|
| 205 |
+
):
|
| 206 |
+
"""Remove a trusted device"""
|
| 207 |
+
try:
|
| 208 |
+
account_service = AccountService()
|
| 209 |
+
|
| 210 |
+
result = await account_service.remove_trusted_device(user_id, device_id)
|
| 211 |
+
|
| 212 |
+
return {"message": "Trusted device removed successfully", "result": result}
|
| 213 |
+
except ValueError as e:
|
| 214 |
+
logger.warning(f"Invalid device removal request for user {user_id}: {str(e)}")
|
| 215 |
+
raise HTTPException(status_code=400, detail=str(e))
|
| 216 |
+
except Exception as e:
|
| 217 |
+
logger.error(f"Error removing trusted device for user {user_id}: {str(e)}")
|
| 218 |
+
raise HTTPException(status_code=500, detail="Failed to remove trusted device")
|
app/routers/user_router.py
CHANGED
|
@@ -12,8 +12,10 @@ from app.schemas.user_schema import (
|
|
| 12 |
)
|
| 13 |
from app.services.user_service import UserService
|
| 14 |
from app.utils.jwt import create_temp_token, decode_token
|
| 15 |
-
from app.utils.social_utils import verify_google_token, verify_apple_token
|
| 16 |
from app.utils.common_utils import validate_identifier
|
|
|
|
|
|
|
| 17 |
import logging
|
| 18 |
|
| 19 |
logger = logging.getLogger("user_router")
|
|
@@ -187,23 +189,84 @@ async def otp_login_handler(
|
|
| 187 |
|
| 188 |
# 🌐 OAuth Login for Google / Apple
|
| 189 |
@router.post("/oauth-login", response_model=TokenResponse)
|
| 190 |
-
async def oauth_login_handler(payload: OAuthLoginRequest):
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
|
| 208 |
# 👤 Final user registration after OTP or OAuth
|
| 209 |
@router.post("/register", response_model=TokenResponse)
|
|
|
|
| 12 |
)
|
| 13 |
from app.services.user_service import UserService
|
| 14 |
from app.utils.jwt import create_temp_token, decode_token
|
| 15 |
+
from app.utils.social_utils import verify_google_token, verify_apple_token, verify_facebook_token
|
| 16 |
from app.utils.common_utils import validate_identifier
|
| 17 |
+
from app.models.social_security_model import SocialSecurityModel
|
| 18 |
+
from fastapi import Request
|
| 19 |
import logging
|
| 20 |
|
| 21 |
logger = logging.getLogger("user_router")
|
|
|
|
| 189 |
|
| 190 |
# 🌐 OAuth Login for Google / Apple
|
| 191 |
@router.post("/oauth-login", response_model=TokenResponse)
|
| 192 |
+
async def oauth_login_handler(payload: OAuthLoginRequest, request: Request):
|
| 193 |
+
from app.core.config import settings
|
| 194 |
+
|
| 195 |
+
# Get client IP
|
| 196 |
+
client_ip = request.client.host if request.client else None
|
| 197 |
+
|
| 198 |
+
# Check if IP is locked for this provider
|
| 199 |
+
if await SocialSecurityModel.is_oauth_ip_locked(client_ip, payload.provider):
|
| 200 |
+
await SocialSecurityModel.log_oauth_attempt(client_ip, payload.provider, False)
|
| 201 |
+
raise HTTPException(
|
| 202 |
+
status_code=429,
|
| 203 |
+
detail=f"Too many failed attempts. IP temporarily locked for {payload.provider} OAuth."
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
# Check rate limiting
|
| 207 |
+
if not await SocialSecurityModel.check_oauth_rate_limit(client_ip, payload.provider):
|
| 208 |
+
await SocialSecurityModel.log_oauth_attempt(client_ip, payload.provider, False)
|
| 209 |
+
raise HTTPException(
|
| 210 |
+
status_code=429,
|
| 211 |
+
detail=f"Rate limit exceeded for {payload.provider} OAuth. Please try again later."
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
# Validate token format
|
| 215 |
+
if not await SocialSecurityModel.validate_oauth_token_format(payload.token, payload.provider):
|
| 216 |
+
await SocialSecurityModel.track_oauth_failed_attempt(client_ip, payload.provider)
|
| 217 |
+
await SocialSecurityModel.log_oauth_attempt(client_ip, payload.provider, False)
|
| 218 |
+
raise HTTPException(status_code=400, detail="Invalid token format")
|
| 219 |
+
|
| 220 |
+
# Increment rate limit counter
|
| 221 |
+
await SocialSecurityModel.increment_oauth_rate_limit(client_ip, payload.provider)
|
| 222 |
+
|
| 223 |
+
try:
|
| 224 |
+
if payload.provider == "google":
|
| 225 |
+
if not settings.GOOGLE_CLIENT_ID:
|
| 226 |
+
raise HTTPException(status_code=500, detail="Google OAuth not configured")
|
| 227 |
+
user_info = await verify_google_token(payload.token, settings.GOOGLE_CLIENT_ID)
|
| 228 |
+
user_id = f"google_{user_info.get('sub', user_info.get('id'))}"
|
| 229 |
+
|
| 230 |
+
elif payload.provider == "apple":
|
| 231 |
+
if not settings.APPLE_AUDIENCE:
|
| 232 |
+
raise HTTPException(status_code=500, detail="Apple OAuth not configured")
|
| 233 |
+
user_info = await verify_apple_token(payload.token, settings.APPLE_AUDIENCE)
|
| 234 |
+
user_id = f"apple_{user_info.get('sub', user_info.get('id'))}"
|
| 235 |
+
|
| 236 |
+
elif payload.provider == "facebook":
|
| 237 |
+
if not settings.FACEBOOK_APP_ID or not settings.FACEBOOK_APP_SECRET:
|
| 238 |
+
raise HTTPException(status_code=500, detail="Facebook OAuth not configured")
|
| 239 |
+
user_info = await verify_facebook_token(payload.token, settings.FACEBOOK_APP_ID, settings.FACEBOOK_APP_SECRET)
|
| 240 |
+
user_id = f"facebook_{user_info.get('id')}"
|
| 241 |
+
|
| 242 |
+
else:
|
| 243 |
+
raise HTTPException(status_code=400, detail="Unsupported OAuth provider")
|
| 244 |
+
|
| 245 |
+
# Clear failed attempts on successful verification
|
| 246 |
+
await SocialSecurityModel.clear_oauth_failed_attempts(client_ip, payload.provider)
|
| 247 |
+
|
| 248 |
+
# Log successful attempt
|
| 249 |
+
await SocialSecurityModel.log_oauth_attempt(client_ip, payload.provider, True, user_id)
|
| 250 |
+
|
| 251 |
+
temp_token = create_temp_token({
|
| 252 |
+
"sub": user_id,
|
| 253 |
+
"type": "oauth_session",
|
| 254 |
+
"verified": True,
|
| 255 |
+
"provider": payload.provider,
|
| 256 |
+
"user_info": user_info
|
| 257 |
+
}, expires_minutes=10)
|
| 258 |
+
|
| 259 |
+
return {"access_token": temp_token, "token_type": "bearer"}
|
| 260 |
+
|
| 261 |
+
except HTTPException:
|
| 262 |
+
# Re-raise HTTP exceptions (configuration errors, etc.)
|
| 263 |
+
raise
|
| 264 |
+
except Exception as e:
|
| 265 |
+
# Track failed attempt for token verification failures
|
| 266 |
+
await SocialSecurityModel.track_oauth_failed_attempt(client_ip, payload.provider)
|
| 267 |
+
await SocialSecurityModel.log_oauth_attempt(client_ip, payload.provider, False)
|
| 268 |
+
logger.error(f"OAuth verification failed for {payload.provider}: {str(e)}", exc_info=True)
|
| 269 |
+
raise HTTPException(status_code=401, detail="OAuth token verification failed")
|
| 270 |
|
| 271 |
# 👤 Final user registration after OTP or OAuth
|
| 272 |
@router.post("/register", response_model=TokenResponse)
|
app/schemas/user_schema.py
CHANGED
|
@@ -1,17 +1,18 @@
|
|
| 1 |
from pydantic import BaseModel, EmailStr, validator
|
| 2 |
-
from typing import Optional, Literal
|
|
|
|
| 3 |
import re
|
| 4 |
|
| 5 |
# Used for OTP-based or OAuth-based user registration
|
| 6 |
class UserRegisterRequest(BaseModel):
|
| 7 |
name: str
|
| 8 |
-
email:
|
| 9 |
-
phone:
|
| 10 |
-
otpIdentifer: Optional[str] = None #
|
| 11 |
otp: Optional[str] = None
|
| 12 |
dob: Optional[str] = None # ISO format date string
|
| 13 |
oauth_token: Optional[str] = None
|
| 14 |
-
provider: Optional[Literal["google", "apple"]] = None
|
| 15 |
mode: Literal["otp", "oauth"]
|
| 16 |
|
| 17 |
@validator('phone')
|
|
@@ -108,18 +109,77 @@ class OTPVerifyRequest(BaseModel):
|
|
| 108 |
|
| 109 |
# OAuth login using Google/Apple
|
| 110 |
class OAuthLoginRequest(BaseModel):
|
| 111 |
-
provider: Literal["google", "apple"]
|
| 112 |
token: str
|
| 113 |
|
| 114 |
-
# JWT Token response format
|
| 115 |
class TokenResponse(BaseModel):
|
| 116 |
access_token: str
|
| 117 |
token_type: str = "bearer"
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
|
| 120 |
-
#
|
| 121 |
class UserProfileResponse(BaseModel):
|
| 122 |
user_id: str
|
| 123 |
-
|
| 124 |
email: Optional[EmailStr] = None
|
| 125 |
-
phone: Optional[str] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from pydantic import BaseModel, EmailStr, validator
|
| 2 |
+
from typing import Optional, Literal, List, Dict, Any
|
| 3 |
+
from datetime import datetime
|
| 4 |
import re
|
| 5 |
|
| 6 |
# Used for OTP-based or OAuth-based user registration
|
| 7 |
class UserRegisterRequest(BaseModel):
|
| 8 |
name: str
|
| 9 |
+
email: EmailStr # Mandatory for all registration modes
|
| 10 |
+
phone: str # Mandatory for all registration modes (always used as OTP identifier)
|
| 11 |
+
otpIdentifer: Optional[str] = None # Deprecated - phone is always the OTP identifier
|
| 12 |
otp: Optional[str] = None
|
| 13 |
dob: Optional[str] = None # ISO format date string
|
| 14 |
oauth_token: Optional[str] = None
|
| 15 |
+
provider: Optional[Literal["google", "apple", "facebook"]] = None
|
| 16 |
mode: Literal["otp", "oauth"]
|
| 17 |
|
| 18 |
@validator('phone')
|
|
|
|
| 109 |
|
| 110 |
# OAuth login using Google/Apple
|
| 111 |
class OAuthLoginRequest(BaseModel):
|
| 112 |
+
provider: Literal["google", "apple", "facebook"]
|
| 113 |
token: str
|
| 114 |
|
| 115 |
+
# JWT Token response format with enhanced security info
|
| 116 |
class TokenResponse(BaseModel):
|
| 117 |
access_token: str
|
| 118 |
token_type: str = "bearer"
|
| 119 |
+
expires_in: Optional[int] = 28800 # 8 hours in seconds
|
| 120 |
+
user_id: Optional[str] = None
|
| 121 |
+
name: Optional[str] = None
|
| 122 |
+
email: Optional[str] = None
|
| 123 |
+
profile_picture: Optional[str] = None
|
| 124 |
+
auth_method: Optional[str] = None # "otp" or "oauth"
|
| 125 |
+
provider: Optional[str] = None # For OAuth logins
|
| 126 |
+
security_info: Optional[Dict[str, Any]] = None
|
| 127 |
|
| 128 |
+
# Enhanced user profile response with social accounts
|
| 129 |
class UserProfileResponse(BaseModel):
|
| 130 |
user_id: str
|
| 131 |
+
name: str
|
| 132 |
email: Optional[EmailStr] = None
|
| 133 |
+
phone: Optional[str] = None
|
| 134 |
+
profile_picture: Optional[str] = None
|
| 135 |
+
auth_method: str
|
| 136 |
+
created_at: datetime
|
| 137 |
+
social_accounts: Optional[List[Dict[str, Any]]] = None
|
| 138 |
+
security_info: Optional[Dict[str, Any]] = None
|
| 139 |
+
|
| 140 |
+
# Social account information
|
| 141 |
+
class SocialAccountInfo(BaseModel):
|
| 142 |
+
provider: str
|
| 143 |
+
email: Optional[str] = None
|
| 144 |
+
name: Optional[str] = None
|
| 145 |
+
linked_at: datetime
|
| 146 |
+
last_login: Optional[datetime] = None
|
| 147 |
+
|
| 148 |
+
# Social account summary response
|
| 149 |
+
class SocialAccountSummary(BaseModel):
|
| 150 |
+
linked_accounts: List[SocialAccountInfo]
|
| 151 |
+
total_accounts: int
|
| 152 |
+
profile_picture: Optional[str] = None
|
| 153 |
+
|
| 154 |
+
# Account linking request
|
| 155 |
+
class LinkSocialAccountRequest(BaseModel):
|
| 156 |
+
provider: Literal["google", "apple", "facebook"]
|
| 157 |
+
token: str
|
| 158 |
+
|
| 159 |
+
# Account unlinking request
|
| 160 |
+
class UnlinkSocialAccountRequest(BaseModel):
|
| 161 |
+
provider: Literal["google", "apple", "facebook"]
|
| 162 |
+
|
| 163 |
+
# Login history entry
|
| 164 |
+
class LoginHistoryEntry(BaseModel):
|
| 165 |
+
timestamp: datetime
|
| 166 |
+
method: str # "otp" or "oauth"
|
| 167 |
+
provider: Optional[str] = None
|
| 168 |
+
ip_address: Optional[str] = None
|
| 169 |
+
success: bool
|
| 170 |
+
device_info: Optional[str] = None
|
| 171 |
+
|
| 172 |
+
# Login history response
|
| 173 |
+
class LoginHistoryResponse(BaseModel):
|
| 174 |
+
entries: List[LoginHistoryEntry]
|
| 175 |
+
total_entries: int
|
| 176 |
+
page: int
|
| 177 |
+
per_page: int
|
| 178 |
+
|
| 179 |
+
# Security settings response
|
| 180 |
+
class SecuritySettingsResponse(BaseModel):
|
| 181 |
+
two_factor_enabled: bool = False
|
| 182 |
+
linked_social_accounts: int
|
| 183 |
+
last_password_change: Optional[datetime] = None
|
| 184 |
+
recent_login_attempts: int
|
| 185 |
+
account_locked: bool = False
|
app/services/account_service.py
ADDED
|
@@ -0,0 +1,396 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datetime import datetime, timedelta
|
| 2 |
+
from typing import List, Dict, Any, Optional
|
| 3 |
+
import logging
|
| 4 |
+
from bson import ObjectId
|
| 5 |
+
|
| 6 |
+
from app.models.social_account_model import SocialAccountModel
|
| 7 |
+
from app.models.user_model import BookMyServiceUserModel
|
| 8 |
+
from app.schemas.user_schema import (
|
| 9 |
+
SocialAccountSummary, SocialAccountInfo, LoginHistoryResponse,
|
| 10 |
+
LoginHistoryEntry, SecuritySettingsResponse
|
| 11 |
+
)
|
| 12 |
+
from app.utils.social_utils import verify_google_token, verify_apple_token, verify_facebook_token
|
| 13 |
+
from app.core.nosql_client import db
|
| 14 |
+
|
| 15 |
+
# Configure logging
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
class AccountService:
|
| 19 |
+
"""Service for managing user accounts, social accounts, and security settings"""
|
| 20 |
+
|
| 21 |
+
def __init__(self):
|
| 22 |
+
self.security_collection = db.get_collection("security_logs")
|
| 23 |
+
self.device_collection = db.get_collection("device_tracking")
|
| 24 |
+
self.session_collection = db.get_collection("user_sessions")
|
| 25 |
+
|
| 26 |
+
async def get_social_account_summary(self, user_id: str) -> SocialAccountSummary:
|
| 27 |
+
"""Get summary of all linked social accounts for a user"""
|
| 28 |
+
try:
|
| 29 |
+
social_accounts = await SocialAccountModel.find_by_user_id(user_id)
|
| 30 |
+
|
| 31 |
+
linked_accounts = []
|
| 32 |
+
profile_picture = None
|
| 33 |
+
|
| 34 |
+
for account in social_accounts:
|
| 35 |
+
account_info = SocialAccountInfo(
|
| 36 |
+
provider=account["provider"],
|
| 37 |
+
email=account.get("email"),
|
| 38 |
+
name=account.get("name"),
|
| 39 |
+
linked_at=account["created_at"],
|
| 40 |
+
last_login=account.get("last_login")
|
| 41 |
+
)
|
| 42 |
+
linked_accounts.append(account_info)
|
| 43 |
+
|
| 44 |
+
# Use the first available profile picture
|
| 45 |
+
if not profile_picture and account.get("profile_picture"):
|
| 46 |
+
profile_picture = account["profile_picture"]
|
| 47 |
+
|
| 48 |
+
return SocialAccountSummary(
|
| 49 |
+
linked_accounts=linked_accounts,
|
| 50 |
+
total_accounts=len(linked_accounts),
|
| 51 |
+
profile_picture=profile_picture
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
except Exception as e:
|
| 55 |
+
logger.error(f"Error getting social account summary for user {user_id}: {str(e)}")
|
| 56 |
+
raise
|
| 57 |
+
|
| 58 |
+
async def link_social_account(self, user_id: str, provider: str, token: str, client_ip: str) -> Dict[str, Any]:
|
| 59 |
+
"""Link a new social account to an existing user"""
|
| 60 |
+
try:
|
| 61 |
+
# Verify the token and get user info
|
| 62 |
+
user_info = await self._verify_social_token(provider, token)
|
| 63 |
+
|
| 64 |
+
# Check if this social account is already linked to another user
|
| 65 |
+
existing_account = await SocialAccountModel.find_by_provider_id(
|
| 66 |
+
provider, user_info["id"]
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
if existing_account and existing_account["user_id"] != user_id:
|
| 70 |
+
raise ValueError(f"This {provider} account is already linked to another user")
|
| 71 |
+
|
| 72 |
+
# Check if user already has this provider linked
|
| 73 |
+
user_provider_account = await SocialAccountModel.find_by_user_and_provider(
|
| 74 |
+
user_id, provider
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
if user_provider_account:
|
| 78 |
+
# Update existing account
|
| 79 |
+
await SocialAccountModel.update_social_account(
|
| 80 |
+
user_id, provider, user_info, client_ip
|
| 81 |
+
)
|
| 82 |
+
action = "updated"
|
| 83 |
+
else:
|
| 84 |
+
# Create new social account link
|
| 85 |
+
await SocialAccountModel.create_social_account(
|
| 86 |
+
user_id, provider, user_info, client_ip
|
| 87 |
+
)
|
| 88 |
+
action = "linked"
|
| 89 |
+
|
| 90 |
+
# Log the action
|
| 91 |
+
await self._log_account_action(
|
| 92 |
+
user_id, f"social_account_{action}",
|
| 93 |
+
{"provider": provider, "client_ip": client_ip}
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
return {"action": action, "provider": provider, "user_info": user_info}
|
| 97 |
+
|
| 98 |
+
except Exception as e:
|
| 99 |
+
logger.error(f"Error linking social account for user {user_id}: {str(e)}")
|
| 100 |
+
raise
|
| 101 |
+
|
| 102 |
+
async def unlink_social_account(self, user_id: str, provider: str) -> Dict[str, Any]:
|
| 103 |
+
"""Unlink a social account from a user"""
|
| 104 |
+
try:
|
| 105 |
+
# Check if account exists
|
| 106 |
+
account = await SocialAccountModel.find_by_user_and_provider(user_id, provider)
|
| 107 |
+
if not account:
|
| 108 |
+
raise ValueError(f"No {provider} account found for this user")
|
| 109 |
+
|
| 110 |
+
# Check if this is the only authentication method
|
| 111 |
+
user = await BookMyServiceUserModel.find_by_id(user_id)
|
| 112 |
+
if not user:
|
| 113 |
+
raise ValueError("User not found")
|
| 114 |
+
|
| 115 |
+
# Count total social accounts
|
| 116 |
+
social_accounts = await SocialAccountModel.find_by_user_id(user_id)
|
| 117 |
+
|
| 118 |
+
# If user has no phone/email and this is their only social account, prevent unlinking
|
| 119 |
+
if (len(social_accounts) == 1 and
|
| 120 |
+
not user.get("phone") and not user.get("email")):
|
| 121 |
+
raise ValueError("Cannot unlink the only authentication method")
|
| 122 |
+
|
| 123 |
+
# Unlink the account
|
| 124 |
+
result = await SocialAccountModel.unlink_social_account(user_id, provider)
|
| 125 |
+
|
| 126 |
+
# Log the action
|
| 127 |
+
await self._log_account_action(
|
| 128 |
+
user_id, "social_account_unlinked",
|
| 129 |
+
{"provider": provider}
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
return {"action": "unlinked", "provider": provider, "result": result}
|
| 133 |
+
|
| 134 |
+
except Exception as e:
|
| 135 |
+
logger.error(f"Error unlinking social account for user {user_id}: {str(e)}")
|
| 136 |
+
raise
|
| 137 |
+
|
| 138 |
+
async def get_login_history(self, user_id: str, page: int = 1,
|
| 139 |
+
per_page: int = 10, days: int = 30) -> LoginHistoryResponse:
|
| 140 |
+
"""Get login history for a user"""
|
| 141 |
+
try:
|
| 142 |
+
# Calculate date range
|
| 143 |
+
end_date = datetime.utcnow()
|
| 144 |
+
start_date = end_date - timedelta(days=days)
|
| 145 |
+
|
| 146 |
+
# Query security logs for login events
|
| 147 |
+
skip = (page - 1) * per_page
|
| 148 |
+
|
| 149 |
+
pipeline = [
|
| 150 |
+
{
|
| 151 |
+
"$match": {
|
| 152 |
+
"user_id": user_id,
|
| 153 |
+
"timestamp": {"$gte": start_date, "$lte": end_date},
|
| 154 |
+
"$or": [
|
| 155 |
+
{"path": {"$regex": "/login"}},
|
| 156 |
+
{"path": {"$regex": "/oauth"}},
|
| 157 |
+
{"path": {"$regex": "/otp"}}
|
| 158 |
+
]
|
| 159 |
+
}
|
| 160 |
+
},
|
| 161 |
+
{"$sort": {"timestamp": -1}},
|
| 162 |
+
{"$skip": skip},
|
| 163 |
+
{"$limit": per_page}
|
| 164 |
+
]
|
| 165 |
+
|
| 166 |
+
cursor = self.security_collection.aggregate(pipeline)
|
| 167 |
+
logs = await cursor.to_list(length=per_page)
|
| 168 |
+
|
| 169 |
+
# Count total entries
|
| 170 |
+
total_count = await self.security_collection.count_documents({
|
| 171 |
+
"user_id": user_id,
|
| 172 |
+
"timestamp": {"$gte": start_date, "$lte": end_date},
|
| 173 |
+
"$or": [
|
| 174 |
+
{"path": {"$regex": "/login"}},
|
| 175 |
+
{"path": {"$regex": "/oauth"}},
|
| 176 |
+
{"path": {"$regex": "/otp"}}
|
| 177 |
+
]
|
| 178 |
+
})
|
| 179 |
+
|
| 180 |
+
# Convert to response format
|
| 181 |
+
entries = []
|
| 182 |
+
for log in logs:
|
| 183 |
+
method = "oauth" if "oauth" in log["path"] else "otp"
|
| 184 |
+
provider = None
|
| 185 |
+
|
| 186 |
+
# Extract provider from query params if available
|
| 187 |
+
if method == "oauth" and log.get("query_params"):
|
| 188 |
+
provider = log["query_params"].get("provider")
|
| 189 |
+
|
| 190 |
+
entry = LoginHistoryEntry(
|
| 191 |
+
timestamp=log["timestamp"],
|
| 192 |
+
method=method,
|
| 193 |
+
provider=provider,
|
| 194 |
+
ip_address=log.get("client_ip"),
|
| 195 |
+
success=log["status_code"] < 400,
|
| 196 |
+
device_info=log.get("device_info", {}).get("user_agent")
|
| 197 |
+
)
|
| 198 |
+
entries.append(entry)
|
| 199 |
+
|
| 200 |
+
return LoginHistoryResponse(
|
| 201 |
+
entries=entries,
|
| 202 |
+
total_entries=total_count,
|
| 203 |
+
page=page,
|
| 204 |
+
per_page=per_page
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
except Exception as e:
|
| 208 |
+
logger.error(f"Error getting login history for user {user_id}: {str(e)}")
|
| 209 |
+
raise
|
| 210 |
+
|
| 211 |
+
async def get_security_settings(self, user_id: str) -> SecuritySettingsResponse:
|
| 212 |
+
"""Get security settings and status for a user"""
|
| 213 |
+
try:
|
| 214 |
+
# Get user info
|
| 215 |
+
user = await BookMyServiceUserModel.find_by_id(user_id)
|
| 216 |
+
if not user:
|
| 217 |
+
raise ValueError("User not found")
|
| 218 |
+
|
| 219 |
+
# Count linked social accounts
|
| 220 |
+
social_accounts = await SocialAccountModel.find_by_user_id(user_id)
|
| 221 |
+
linked_accounts_count = len(social_accounts)
|
| 222 |
+
|
| 223 |
+
# Get recent login attempts (last 24 hours)
|
| 224 |
+
yesterday = datetime.utcnow() - timedelta(days=1)
|
| 225 |
+
recent_attempts = await self.security_collection.count_documents({
|
| 226 |
+
"user_id": user_id,
|
| 227 |
+
"timestamp": {"$gte": yesterday},
|
| 228 |
+
"$or": [
|
| 229 |
+
{"path": {"$regex": "/login"}},
|
| 230 |
+
{"path": {"$regex": "/oauth"}},
|
| 231 |
+
{"path": {"$regex": "/otp"}}
|
| 232 |
+
]
|
| 233 |
+
})
|
| 234 |
+
|
| 235 |
+
# Check if account is locked (this would be implemented based on your locking logic)
|
| 236 |
+
account_locked = False # Implement based on your account locking mechanism
|
| 237 |
+
|
| 238 |
+
return SecuritySettingsResponse(
|
| 239 |
+
two_factor_enabled=False, # Implement 2FA if needed
|
| 240 |
+
linked_social_accounts=linked_accounts_count,
|
| 241 |
+
last_password_change=None, # Implement if you have password functionality
|
| 242 |
+
recent_login_attempts=recent_attempts,
|
| 243 |
+
account_locked=account_locked
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
except Exception as e:
|
| 247 |
+
logger.error(f"Error getting security settings for user {user_id}: {str(e)}")
|
| 248 |
+
raise
|
| 249 |
+
|
| 250 |
+
async def merge_social_accounts(self, primary_user_id: str, secondary_user_id: str,
|
| 251 |
+
client_ip: str) -> Dict[str, Any]:
|
| 252 |
+
"""Merge social accounts from secondary user to primary user"""
|
| 253 |
+
try:
|
| 254 |
+
# Get social accounts from secondary user
|
| 255 |
+
secondary_accounts = await SocialAccountModel.find_by_user_id(secondary_user_id)
|
| 256 |
+
|
| 257 |
+
merged_count = 0
|
| 258 |
+
for account in secondary_accounts:
|
| 259 |
+
# Check if primary user already has this provider
|
| 260 |
+
existing = await SocialAccountModel.find_by_user_and_provider(
|
| 261 |
+
primary_user_id, account["provider"]
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
if not existing:
|
| 265 |
+
# Transfer the account to primary user
|
| 266 |
+
await SocialAccountModel.update_user_id(
|
| 267 |
+
account["_id"], primary_user_id
|
| 268 |
+
)
|
| 269 |
+
merged_count += 1
|
| 270 |
+
|
| 271 |
+
# Log the merge action
|
| 272 |
+
await self._log_account_action(
|
| 273 |
+
primary_user_id, "accounts_merged",
|
| 274 |
+
{
|
| 275 |
+
"secondary_user_id": secondary_user_id,
|
| 276 |
+
"merged_accounts": merged_count,
|
| 277 |
+
"client_ip": client_ip
|
| 278 |
+
}
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
return {
|
| 282 |
+
"merged_accounts": merged_count,
|
| 283 |
+
"primary_user_id": primary_user_id,
|
| 284 |
+
"secondary_user_id": secondary_user_id
|
| 285 |
+
}
|
| 286 |
+
|
| 287 |
+
except Exception as e:
|
| 288 |
+
logger.error(f"Error merging accounts {secondary_user_id} -> {primary_user_id}: {str(e)}")
|
| 289 |
+
raise
|
| 290 |
+
|
| 291 |
+
async def revoke_all_sessions(self, user_id: str, client_ip: str) -> Dict[str, Any]:
|
| 292 |
+
"""Revoke all active sessions for a user"""
|
| 293 |
+
try:
|
| 294 |
+
# In a real implementation, you'd have a sessions collection
|
| 295 |
+
# For now, we'll just log the action
|
| 296 |
+
await self._log_account_action(
|
| 297 |
+
user_id, "all_sessions_revoked",
|
| 298 |
+
{"client_ip": client_ip}
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
# Here you would typically:
|
| 302 |
+
# 1. Delete all session tokens from database
|
| 303 |
+
# 2. Add tokens to a blacklist
|
| 304 |
+
# 3. Force re-authentication on next request
|
| 305 |
+
|
| 306 |
+
return {"action": "revoked", "user_id": user_id}
|
| 307 |
+
|
| 308 |
+
except Exception as e:
|
| 309 |
+
logger.error(f"Error revoking sessions for user {user_id}: {str(e)}")
|
| 310 |
+
raise
|
| 311 |
+
|
| 312 |
+
async def get_trusted_devices(self, user_id: str) -> List[Dict[str, Any]]:
|
| 313 |
+
"""Get list of trusted devices for a user"""
|
| 314 |
+
try:
|
| 315 |
+
cursor = self.device_collection.find({
|
| 316 |
+
"user_id": user_id,
|
| 317 |
+
"is_trusted": True
|
| 318 |
+
}).sort("last_seen", -1)
|
| 319 |
+
|
| 320 |
+
devices = await cursor.to_list(length=None)
|
| 321 |
+
|
| 322 |
+
# Format device information
|
| 323 |
+
trusted_devices = []
|
| 324 |
+
for device in devices:
|
| 325 |
+
device_info = {
|
| 326 |
+
"device_id": str(device["_id"]),
|
| 327 |
+
"device_fingerprint": device["device_fingerprint"],
|
| 328 |
+
"platform": device.get("device_info", {}).get("platform", "Unknown"),
|
| 329 |
+
"browser": device.get("device_info", {}).get("browser", "Unknown"),
|
| 330 |
+
"first_seen": device["first_seen"],
|
| 331 |
+
"last_seen": device["last_seen"],
|
| 332 |
+
"access_count": device.get("access_count", 0)
|
| 333 |
+
}
|
| 334 |
+
trusted_devices.append(device_info)
|
| 335 |
+
|
| 336 |
+
return trusted_devices
|
| 337 |
+
|
| 338 |
+
except Exception as e:
|
| 339 |
+
logger.error(f"Error getting trusted devices for user {user_id}: {str(e)}")
|
| 340 |
+
raise
|
| 341 |
+
|
| 342 |
+
async def remove_trusted_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
|
| 343 |
+
"""Remove a trusted device"""
|
| 344 |
+
try:
|
| 345 |
+
result = await self.device_collection.update_one(
|
| 346 |
+
{
|
| 347 |
+
"_id": ObjectId(device_id),
|
| 348 |
+
"user_id": user_id
|
| 349 |
+
},
|
| 350 |
+
{"$set": {"is_trusted": False}}
|
| 351 |
+
)
|
| 352 |
+
|
| 353 |
+
if result.matched_count == 0:
|
| 354 |
+
raise ValueError("Device not found or not owned by user")
|
| 355 |
+
|
| 356 |
+
await self._log_account_action(
|
| 357 |
+
user_id, "trusted_device_removed",
|
| 358 |
+
{"device_id": device_id}
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
+
return {"action": "removed", "device_id": device_id}
|
| 362 |
+
|
| 363 |
+
except Exception as e:
|
| 364 |
+
logger.error(f"Error removing trusted device for user {user_id}: {str(e)}")
|
| 365 |
+
raise
|
| 366 |
+
|
| 367 |
+
async def _verify_social_token(self, provider: str, token: str) -> Dict[str, Any]:
|
| 368 |
+
"""Verify social media token and return user info"""
|
| 369 |
+
try:
|
| 370 |
+
if provider == "google":
|
| 371 |
+
return await verify_google_token(token)
|
| 372 |
+
elif provider == "apple":
|
| 373 |
+
return await verify_apple_token(token)
|
| 374 |
+
elif provider == "facebook":
|
| 375 |
+
return await verify_facebook_token(token)
|
| 376 |
+
else:
|
| 377 |
+
raise ValueError(f"Unsupported provider: {provider}")
|
| 378 |
+
except Exception as e:
|
| 379 |
+
logger.error(f"Token verification failed for {provider}: {str(e)}")
|
| 380 |
+
raise ValueError(f"Invalid {provider} token")
|
| 381 |
+
|
| 382 |
+
async def _log_account_action(self, user_id: str, action: str, details: Dict[str, Any]):
|
| 383 |
+
"""Log account-related actions for audit purposes"""
|
| 384 |
+
try:
|
| 385 |
+
log_entry = {
|
| 386 |
+
"timestamp": datetime.utcnow(),
|
| 387 |
+
"user_id": user_id,
|
| 388 |
+
"action": action,
|
| 389 |
+
"details": details,
|
| 390 |
+
"type": "account_management"
|
| 391 |
+
}
|
| 392 |
+
|
| 393 |
+
await self.security_collection.insert_one(log_entry)
|
| 394 |
+
|
| 395 |
+
except Exception as e:
|
| 396 |
+
logger.error(f"Failed to log account action: {str(e)}")
|
app/services/user_service.py
CHANGED
|
@@ -4,6 +4,7 @@ from datetime import datetime, timedelta
|
|
| 4 |
from fastapi import HTTPException
|
| 5 |
from app.models.user_model import BookMyServiceUserModel
|
| 6 |
from app.models.otp_model import BookMyServiceOTPModel
|
|
|
|
| 7 |
from app.core.config import settings
|
| 8 |
from app.utils.common_utils import is_email, validate_identifier
|
| 9 |
from app.schemas.user_schema import UserRegisterRequest
|
|
@@ -11,16 +12,27 @@ import logging
|
|
| 11 |
|
| 12 |
logger = logging.getLogger("user_service")
|
| 13 |
|
|
|
|
|
|
|
|
|
|
| 14 |
class UserService:
|
| 15 |
@staticmethod
|
| 16 |
-
async def send_otp(identifier: str, phone: str = None):
|
| 17 |
-
logger.info(f"UserService.send_otp called - identifier: {identifier}, phone: {phone}")
|
| 18 |
|
| 19 |
try:
|
| 20 |
# Validate identifier format
|
| 21 |
identifier_type = validate_identifier(identifier)
|
| 22 |
logger.debug(f"Identifier type: {identifier_type}")
|
| 23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
# For phone identifiers, use the identifier itself as phone
|
| 25 |
# For email identifiers, use the provided phone parameter
|
| 26 |
if identifier_type == "phone":
|
|
@@ -31,13 +43,19 @@ class UserService:
|
|
| 31 |
# If email identifier but no phone provided, we'll send OTP via email
|
| 32 |
phone_number = None
|
| 33 |
|
| 34 |
-
#
|
| 35 |
-
|
| 36 |
-
|
|
|
|
| 37 |
|
| 38 |
await BookMyServiceOTPModel.store_otp(identifier, phone_number, otp)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
logger.info(f"OTP stored successfully for identifier: {identifier}")
|
| 40 |
-
logger.debug(f"OTP sent to {identifier}
|
| 41 |
|
| 42 |
except ValueError as ve:
|
| 43 |
logger.error(f"Validation error for identifier {identifier}: {str(ve)}")
|
|
@@ -47,23 +65,32 @@ class UserService:
|
|
| 47 |
raise HTTPException(status_code=500, detail="Failed to send OTP")
|
| 48 |
|
| 49 |
@staticmethod
|
| 50 |
-
async def otp_login_handler(identifier: str, otp: str):
|
| 51 |
-
logger.info(f"UserService.otp_login_handler called - identifier: {identifier}, otp: {otp}")
|
| 52 |
|
| 53 |
try:
|
| 54 |
# Validate identifier format
|
| 55 |
identifier_type = validate_identifier(identifier)
|
| 56 |
logger.debug(f"Identifier type: {identifier_type}")
|
| 57 |
|
| 58 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
logger.debug(f"Verifying OTP for identifier: {identifier}")
|
| 60 |
-
otp_valid = await BookMyServiceOTPModel.verify_otp(identifier, otp)
|
| 61 |
logger.debug(f"OTP verification result: {otp_valid}")
|
| 62 |
|
| 63 |
if not otp_valid:
|
| 64 |
logger.warning(f"Invalid or expired OTP for identifier: {identifier}")
|
|
|
|
|
|
|
| 65 |
raise HTTPException(status_code=400, detail="Invalid or expired OTP")
|
| 66 |
|
|
|
|
|
|
|
| 67 |
logger.info(f"OTP verification successful for identifier: {identifier}")
|
| 68 |
|
| 69 |
# Find user by identifier
|
|
@@ -108,17 +135,28 @@ class UserService:
|
|
| 108 |
async def register(data: UserRegisterRequest, decoded):
|
| 109 |
logger.info(f"Registering user with data: {data}")
|
| 110 |
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
|
| 116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
try:
|
| 118 |
identifier_type = validate_identifier(identifier)
|
|
|
|
|
|
|
| 119 |
logger.debug(f"Registration identifier type: {identifier_type}")
|
| 120 |
except ValueError as ve:
|
| 121 |
-
logger.error(f"Invalid
|
| 122 |
raise HTTPException(status_code=400, detail=str(ve))
|
| 123 |
|
| 124 |
redis_key = f"bms_otp:{identifier}"
|
|
@@ -133,9 +171,44 @@ class UserService:
|
|
| 133 |
user_id = f"otp_{identifier}"
|
| 134 |
|
| 135 |
elif data.mode == "oauth":
|
|
|
|
| 136 |
if not data.oauth_token or not data.provider:
|
| 137 |
-
raise HTTPException(status_code=400, detail="OAuth token and provider required")
|
| 138 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
else:
|
| 140 |
raise HTTPException(status_code=400, detail="Unsupported registration mode")
|
| 141 |
|
|
@@ -151,6 +224,7 @@ class UserService:
|
|
| 151 |
if existing_user:
|
| 152 |
raise HTTPException(status_code=409, detail="User with this email or phone already exists")
|
| 153 |
|
|
|
|
| 154 |
user_doc = {
|
| 155 |
"user_id": user_id,
|
| 156 |
"name": data.name,
|
|
@@ -159,7 +233,20 @@ class UserService:
|
|
| 159 |
"auth_mode": data.mode,
|
| 160 |
"created_at": datetime.utcnow()
|
| 161 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
await BookMyServiceUserModel.collection.insert_one(user_doc)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
|
| 164 |
token_data = {
|
| 165 |
"sub": user_id,
|
|
|
|
| 4 |
from fastapi import HTTPException
|
| 5 |
from app.models.user_model import BookMyServiceUserModel
|
| 6 |
from app.models.otp_model import BookMyServiceOTPModel
|
| 7 |
+
from app.models.social_account_model import SocialAccountModel
|
| 8 |
from app.core.config import settings
|
| 9 |
from app.utils.common_utils import is_email, validate_identifier
|
| 10 |
from app.schemas.user_schema import UserRegisterRequest
|
|
|
|
| 12 |
|
| 13 |
logger = logging.getLogger("user_service")
|
| 14 |
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
|
| 18 |
class UserService:
|
| 19 |
@staticmethod
|
| 20 |
+
async def send_otp(identifier: str, phone: str = None, client_ip: str = None):
|
| 21 |
+
logger.info(f"UserService.send_otp called - identifier: {identifier}, phone: {phone}, ip: {client_ip}")
|
| 22 |
|
| 23 |
try:
|
| 24 |
# Validate identifier format
|
| 25 |
identifier_type = validate_identifier(identifier)
|
| 26 |
logger.debug(f"Identifier type: {identifier_type}")
|
| 27 |
|
| 28 |
+
# Enhanced rate limiting by IP and identifier
|
| 29 |
+
if client_ip:
|
| 30 |
+
ip_rate_key = f"otp_ip_rate:{client_ip}"
|
| 31 |
+
ip_attempts = await BookMyServiceOTPModel.get_rate_limit_count(ip_rate_key)
|
| 32 |
+
if ip_attempts >= 10: # Max 10 OTPs per IP per hour
|
| 33 |
+
logger.warning(f"IP rate limit exceeded for {client_ip}")
|
| 34 |
+
raise HTTPException(status_code=429, detail="Too many OTP requests from this IP")
|
| 35 |
+
|
| 36 |
# For phone identifiers, use the identifier itself as phone
|
| 37 |
# For email identifiers, use the provided phone parameter
|
| 38 |
if identifier_type == "phone":
|
|
|
|
| 43 |
# If email identifier but no phone provided, we'll send OTP via email
|
| 44 |
phone_number = None
|
| 45 |
|
| 46 |
+
# Generate secure OTP (6 digits, cryptographically secure)
|
| 47 |
+
import secrets
|
| 48 |
+
otp = ''.join([str(secrets.randbelow(10)) for _ in range(6)])
|
| 49 |
+
logger.debug(f"Generated secure OTP for identifier: {identifier}")
|
| 50 |
|
| 51 |
await BookMyServiceOTPModel.store_otp(identifier, phone_number, otp)
|
| 52 |
+
|
| 53 |
+
# Track IP-based rate limiting
|
| 54 |
+
if client_ip:
|
| 55 |
+
await BookMyServiceOTPModel.increment_rate_limit(ip_rate_key, 3600) # 1 hour window
|
| 56 |
+
|
| 57 |
logger.info(f"OTP stored successfully for identifier: {identifier}")
|
| 58 |
+
logger.debug(f"OTP sent to {identifier}")
|
| 59 |
|
| 60 |
except ValueError as ve:
|
| 61 |
logger.error(f"Validation error for identifier {identifier}: {str(ve)}")
|
|
|
|
| 65 |
raise HTTPException(status_code=500, detail="Failed to send OTP")
|
| 66 |
|
| 67 |
@staticmethod
|
| 68 |
+
async def otp_login_handler(identifier: str, otp: str, client_ip: str = None):
|
| 69 |
+
logger.info(f"UserService.otp_login_handler called - identifier: {identifier}, otp: {otp}, ip: {client_ip}")
|
| 70 |
|
| 71 |
try:
|
| 72 |
# Validate identifier format
|
| 73 |
identifier_type = validate_identifier(identifier)
|
| 74 |
logger.debug(f"Identifier type: {identifier_type}")
|
| 75 |
|
| 76 |
+
# Check if account is locked
|
| 77 |
+
if await BookMyServiceOTPModel.is_account_locked(identifier):
|
| 78 |
+
logger.warning(f"Account locked for identifier: {identifier}")
|
| 79 |
+
raise HTTPException(status_code=423, detail="Account temporarily locked due to too many failed attempts")
|
| 80 |
+
|
| 81 |
+
# Verify OTP with client IP tracking
|
| 82 |
logger.debug(f"Verifying OTP for identifier: {identifier}")
|
| 83 |
+
otp_valid = await BookMyServiceOTPModel.verify_otp(identifier, otp, client_ip)
|
| 84 |
logger.debug(f"OTP verification result: {otp_valid}")
|
| 85 |
|
| 86 |
if not otp_valid:
|
| 87 |
logger.warning(f"Invalid or expired OTP for identifier: {identifier}")
|
| 88 |
+
# Track failed attempt
|
| 89 |
+
await BookMyServiceOTPModel.track_failed_attempt(identifier, client_ip)
|
| 90 |
raise HTTPException(status_code=400, detail="Invalid or expired OTP")
|
| 91 |
|
| 92 |
+
# Clear failed attempts on successful verification
|
| 93 |
+
await BookMyServiceOTPModel.clear_failed_attempts(identifier)
|
| 94 |
logger.info(f"OTP verification successful for identifier: {identifier}")
|
| 95 |
|
| 96 |
# Find user by identifier
|
|
|
|
| 135 |
async def register(data: UserRegisterRequest, decoded):
|
| 136 |
logger.info(f"Registering user with data: {data}")
|
| 137 |
|
| 138 |
+
# Validate mandatory fields for all registration modes
|
| 139 |
+
if not data.name or not data.name.strip():
|
| 140 |
+
raise HTTPException(status_code=400, detail="Name is required")
|
| 141 |
+
|
| 142 |
+
if not data.email:
|
| 143 |
+
raise HTTPException(status_code=400, detail="Email is required")
|
| 144 |
+
|
| 145 |
+
if not data.phone or not data.phone.strip():
|
| 146 |
+
raise HTTPException(status_code=400, detail="Phone is required")
|
| 147 |
|
| 148 |
+
if data.mode == "otp":
|
| 149 |
+
# Always use phone as the OTP identifier as per documentation
|
| 150 |
+
identifier = data.phone
|
| 151 |
+
|
| 152 |
+
# Validate phone format
|
| 153 |
try:
|
| 154 |
identifier_type = validate_identifier(identifier)
|
| 155 |
+
if identifier_type != "phone":
|
| 156 |
+
raise ValueError("Phone number format is invalid")
|
| 157 |
logger.debug(f"Registration identifier type: {identifier_type}")
|
| 158 |
except ValueError as ve:
|
| 159 |
+
logger.error(f"Invalid phone format during registration: {str(ve)}")
|
| 160 |
raise HTTPException(status_code=400, detail=str(ve))
|
| 161 |
|
| 162 |
redis_key = f"bms_otp:{identifier}"
|
|
|
|
| 171 |
user_id = f"otp_{identifier}"
|
| 172 |
|
| 173 |
elif data.mode == "oauth":
|
| 174 |
+
# Validate OAuth-specific mandatory fields
|
| 175 |
if not data.oauth_token or not data.provider:
|
| 176 |
+
raise HTTPException(status_code=400, detail="OAuth token and provider are required")
|
| 177 |
+
|
| 178 |
+
# Extract user info from decoded token
|
| 179 |
+
user_info = decoded.get("user_info", {})
|
| 180 |
+
provider_user_id = user_info.get("sub") or user_info.get("id")
|
| 181 |
+
|
| 182 |
+
if not provider_user_id:
|
| 183 |
+
raise HTTPException(status_code=400, detail="Invalid OAuth user information")
|
| 184 |
+
|
| 185 |
+
# Check if this social account already exists
|
| 186 |
+
existing_social_account = await SocialAccountModel.find_by_provider_and_user_id(
|
| 187 |
+
data.provider, provider_user_id
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
if existing_social_account:
|
| 191 |
+
# User already has this social account linked
|
| 192 |
+
existing_user = await BookMyServiceUserModel.collection.find_one({
|
| 193 |
+
"user_id": existing_social_account["user_id"]
|
| 194 |
+
})
|
| 195 |
+
if existing_user:
|
| 196 |
+
# Update social account with latest info and return existing user token
|
| 197 |
+
await SocialAccountModel.update_social_account(data.provider, provider_user_id, user_info)
|
| 198 |
+
|
| 199 |
+
token_data = {
|
| 200 |
+
"sub": existing_user["user_id"],
|
| 201 |
+
"user_id": existing_user["user_id"],
|
| 202 |
+
"email": existing_user.get("email"),
|
| 203 |
+
"phone": existing_user.get("phone"),
|
| 204 |
+
"role": "user",
|
| 205 |
+
"exp": datetime.utcnow() + timedelta(hours=8)
|
| 206 |
+
}
|
| 207 |
+
access_token = jwt.encode(token_data, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
| 208 |
+
return {"access_token": access_token}
|
| 209 |
+
|
| 210 |
+
user_id = f"{data.provider}_{provider_user_id}"
|
| 211 |
+
|
| 212 |
else:
|
| 213 |
raise HTTPException(status_code=400, detail="Unsupported registration mode")
|
| 214 |
|
|
|
|
| 224 |
if existing_user:
|
| 225 |
raise HTTPException(status_code=409, detail="User with this email or phone already exists")
|
| 226 |
|
| 227 |
+
# Create user document
|
| 228 |
user_doc = {
|
| 229 |
"user_id": user_id,
|
| 230 |
"name": data.name,
|
|
|
|
| 233 |
"auth_mode": data.mode,
|
| 234 |
"created_at": datetime.utcnow()
|
| 235 |
}
|
| 236 |
+
|
| 237 |
+
# Add profile picture from social account if available
|
| 238 |
+
if data.mode == "oauth" and user_info.get("picture"):
|
| 239 |
+
user_doc["profile_picture"] = user_info["picture"]
|
| 240 |
+
|
| 241 |
await BookMyServiceUserModel.collection.insert_one(user_doc)
|
| 242 |
+
logger.info(f"Created new user: {user_id}")
|
| 243 |
+
|
| 244 |
+
# Create social account record for OAuth registration
|
| 245 |
+
if data.mode == "oauth":
|
| 246 |
+
await SocialAccountModel.create_social_account(
|
| 247 |
+
user_id, data.provider, provider_user_id, user_info
|
| 248 |
+
)
|
| 249 |
+
logger.info(f"Created social account link for {data.provider}")
|
| 250 |
|
| 251 |
token_data = {
|
| 252 |
"sub": user_id,
|
app/utils/social_utils.py
CHANGED
|
@@ -15,6 +15,62 @@ class TokenVerificationError(Exception):
|
|
| 15 |
"""Custom exception for token verification errors"""
|
| 16 |
pass
|
| 17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
class GoogleTokenVerifier:
|
| 19 |
def __init__(self, client_id: str):
|
| 20 |
self.client_id = client_id
|
|
@@ -140,9 +196,11 @@ class AppleTokenVerifier:
|
|
| 140 |
|
| 141 |
# Factory class for easier usage
|
| 142 |
class OAuthVerifier:
|
| 143 |
-
def __init__(self, google_client_id: Optional[str] = None, apple_audience: Optional[str] = None
|
|
|
|
| 144 |
self.google_verifier = GoogleTokenVerifier(google_client_id) if google_client_id else None
|
| 145 |
self.apple_verifier = AppleTokenVerifier(apple_audience) if apple_audience else None
|
|
|
|
| 146 |
|
| 147 |
async def verify_google_token(self, token: str) -> Dict:
|
| 148 |
if not self.google_verifier:
|
|
@@ -153,6 +211,11 @@ class OAuthVerifier:
|
|
| 153 |
if not self.apple_verifier:
|
| 154 |
raise TokenVerificationError("Apple verifier not configured")
|
| 155 |
return await self.apple_verifier.verify_token(token)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
|
| 157 |
# Convenience functions (backward compatibility)
|
| 158 |
async def verify_google_token(token: str, client_id: str) -> Dict:
|
|
@@ -169,6 +232,13 @@ async def verify_apple_token(token: str, audience: str) -> Dict:
|
|
| 169 |
verifier = AppleTokenVerifier(audience)
|
| 170 |
return await verifier.verify_token(token)
|
| 171 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
# Example usage
|
| 173 |
async def example_usage():
|
| 174 |
# Initialize verifier
|
|
|
|
| 15 |
"""Custom exception for token verification errors"""
|
| 16 |
pass
|
| 17 |
|
| 18 |
+
class FacebookTokenVerifier:
|
| 19 |
+
def __init__(self, app_id: str, app_secret: str):
|
| 20 |
+
self.app_id = app_id
|
| 21 |
+
self.app_secret = app_secret
|
| 22 |
+
|
| 23 |
+
async def verify_token(self, token: str) -> Dict:
|
| 24 |
+
"""
|
| 25 |
+
Asynchronously verifies a Facebook access token and returns user data.
|
| 26 |
+
"""
|
| 27 |
+
try:
|
| 28 |
+
# First, verify the token with Facebook's debug endpoint
|
| 29 |
+
async with httpx.AsyncClient(timeout=10.0) as client:
|
| 30 |
+
# Verify token validity
|
| 31 |
+
debug_url = f"https://graph.facebook.com/debug_token"
|
| 32 |
+
debug_params = {
|
| 33 |
+
"input_token": token,
|
| 34 |
+
"access_token": f"{self.app_id}|{self.app_secret}"
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
debug_response = await client.get(debug_url, params=debug_params)
|
| 38 |
+
debug_response.raise_for_status()
|
| 39 |
+
debug_data = debug_response.json()
|
| 40 |
+
|
| 41 |
+
if not debug_data.get("data", {}).get("is_valid"):
|
| 42 |
+
raise TokenVerificationError("Invalid Facebook token")
|
| 43 |
+
|
| 44 |
+
# Check if token is for our app
|
| 45 |
+
token_app_id = debug_data.get("data", {}).get("app_id")
|
| 46 |
+
if token_app_id != self.app_id:
|
| 47 |
+
raise TokenVerificationError("Token not for this app")
|
| 48 |
+
|
| 49 |
+
# Get user data
|
| 50 |
+
user_url = "https://graph.facebook.com/me"
|
| 51 |
+
user_params = {
|
| 52 |
+
"access_token": token,
|
| 53 |
+
"fields": "id,name,email,picture.type(large)"
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
user_response = await client.get(user_url, params=user_params)
|
| 57 |
+
user_response.raise_for_status()
|
| 58 |
+
user_data = user_response.json()
|
| 59 |
+
|
| 60 |
+
# Validate required fields
|
| 61 |
+
if not user_data.get("id"):
|
| 62 |
+
raise TokenVerificationError("Missing user ID in Facebook response")
|
| 63 |
+
|
| 64 |
+
logger.info(f"Successfully verified Facebook token for user: {user_data.get('email', user_data.get('id'))}")
|
| 65 |
+
return user_data
|
| 66 |
+
|
| 67 |
+
except httpx.RequestError as e:
|
| 68 |
+
logger.error(f"Facebook token verification request failed: {str(e)}")
|
| 69 |
+
raise TokenVerificationError(f"Facebook API request failed: {str(e)}")
|
| 70 |
+
except Exception as e:
|
| 71 |
+
logger.error(f"Facebook token verification failed: {str(e)}")
|
| 72 |
+
raise TokenVerificationError(f"Invalid Facebook token: {str(e)}")
|
| 73 |
+
|
| 74 |
class GoogleTokenVerifier:
|
| 75 |
def __init__(self, client_id: str):
|
| 76 |
self.client_id = client_id
|
|
|
|
| 196 |
|
| 197 |
# Factory class for easier usage
|
| 198 |
class OAuthVerifier:
|
| 199 |
+
def __init__(self, google_client_id: Optional[str] = None, apple_audience: Optional[str] = None,
|
| 200 |
+
facebook_app_id: Optional[str] = None, facebook_app_secret: Optional[str] = None):
|
| 201 |
self.google_verifier = GoogleTokenVerifier(google_client_id) if google_client_id else None
|
| 202 |
self.apple_verifier = AppleTokenVerifier(apple_audience) if apple_audience else None
|
| 203 |
+
self.facebook_verifier = FacebookTokenVerifier(facebook_app_id, facebook_app_secret) if facebook_app_id and facebook_app_secret else None
|
| 204 |
|
| 205 |
async def verify_google_token(self, token: str) -> Dict:
|
| 206 |
if not self.google_verifier:
|
|
|
|
| 211 |
if not self.apple_verifier:
|
| 212 |
raise TokenVerificationError("Apple verifier not configured")
|
| 213 |
return await self.apple_verifier.verify_token(token)
|
| 214 |
+
|
| 215 |
+
async def verify_facebook_token(self, token: str) -> Dict:
|
| 216 |
+
if not self.facebook_verifier:
|
| 217 |
+
raise TokenVerificationError("Facebook verifier not configured")
|
| 218 |
+
return await self.facebook_verifier.verify_token(token)
|
| 219 |
|
| 220 |
# Convenience functions (backward compatibility)
|
| 221 |
async def verify_google_token(token: str, client_id: str) -> Dict:
|
|
|
|
| 232 |
verifier = AppleTokenVerifier(audience)
|
| 233 |
return await verifier.verify_token(token)
|
| 234 |
|
| 235 |
+
async def verify_facebook_token(token: str, app_id: str, app_secret: str) -> Dict:
|
| 236 |
+
"""
|
| 237 |
+
Asynchronously verifies a Facebook access token and returns user data.
|
| 238 |
+
"""
|
| 239 |
+
verifier = FacebookTokenVerifier(app_id, app_secret)
|
| 240 |
+
return await verifier.verify_token(token)
|
| 241 |
+
|
| 242 |
# Example usage
|
| 243 |
async def example_usage():
|
| 244 |
# Initialize verifier
|