MukeshKapoor25 commited on
Commit
d846c9a
·
1 Parent(s): 91dc3b5

feat(auth): implement enhanced social login with security middleware

Browse files

- Add Facebook OAuth support with token verification
- Implement rate limiting and IP-based security for OAuth logins
- Create account management endpoints for social account linking
- Add security middleware for request logging and device tracking
- Enhance OTP verification with account locking and rate limiting
- Update user schemas with social account and security fields

.env CHANGED
@@ -5,11 +5,12 @@ DB_NAME=book-my-service
5
 
6
  DATABASE_URI=postgresql+asyncpg://trans_owner:BookMyService7@ep-sweet-surf-a1qeduoy.ap-southeast-1.aws.neon.tech/bookmyservice?options=-csearch_path%3Dtrans
7
 
8
- CACHE_URI=redis-11382.c305.ap-south-1-1.ec2.redns.redis-cloud.com:11382
9
 
10
  #CACHE_URI=redis-11521.crce182.ap-south-1-1.ec2.redns.redis-cloud.com:11521
11
 
12
- CACHE_K=dLRZrhU1d5EP9N1CW6grUgsj7MyWIj2i
 
13
 
14
 
15
  RAZORPAY_KEY_ID=rzp_test_2UTAol2AFSV5VN
 
5
 
6
  DATABASE_URI=postgresql+asyncpg://trans_owner:BookMyService7@ep-sweet-surf-a1qeduoy.ap-southeast-1.aws.neon.tech/bookmyservice?options=-csearch_path%3Dtrans
7
 
8
+ #CACHE_URI=redis-11382.c305.ap-south-1-1.ec2.redns.redis-cloud.com:11382
9
 
10
  #CACHE_URI=redis-11521.crce182.ap-south-1-1.ec2.redns.redis-cloud.com:11521
11
 
12
+ CACHE_URI=localhost:6379
13
+ CACHE_K=
14
 
15
 
16
  RAZORPAY_KEY_ID=rzp_test_2UTAol2AFSV5VN
app/app.py CHANGED
@@ -2,7 +2,9 @@
2
 
3
  from fastapi import FastAPI
4
  from fastapi.middleware.cors import CORSMiddleware
5
- from app.routers import user_router, profile_router
 
 
6
  import logging
7
  import sys
8
 
@@ -29,6 +31,12 @@ logging.getLogger("fastapi").setLevel(logging.INFO)
29
 
30
  app = FastAPI(title="BookMyService User Management Service")
31
 
 
 
 
 
 
 
32
  app.add_middleware(
33
  CORSMiddleware,
34
  allow_origins=["*"],
@@ -39,6 +47,7 @@ app.add_middleware(
39
 
40
  app.include_router(user_router.router, prefix="/auth", tags=["user_auth"])
41
  app.include_router(profile_router.router, prefix="/profile", tags=["profile"])
 
42
 
43
  @app.get("/")
44
  def root():
 
2
 
3
  from fastapi import FastAPI
4
  from fastapi.middleware.cors import CORSMiddleware
5
+ from app.routers import user_router, profile_router, account_router
6
+ from app.middleware.rate_limiter import RateLimitMiddleware
7
+ from app.middleware.security_middleware import SecurityMiddleware
8
  import logging
9
  import sys
10
 
 
31
 
32
  app = FastAPI(title="BookMyService User Management Service")
33
 
34
+ # Add security middleware (should be added first for proper request logging)
35
+ app.add_middleware(SecurityMiddleware)
36
+
37
+ # Add rate limiting middleware
38
+ app.add_middleware(RateLimitMiddleware, calls=100, period=60)
39
+
40
  app.add_middleware(
41
  CORSMiddleware,
42
  allow_origins=["*"],
 
47
 
48
  app.include_router(user_router.router, prefix="/auth", tags=["user_auth"])
49
  app.include_router(profile_router.router, prefix="/profile", tags=["profile"])
50
+ app.include_router(account_router.router, prefix="/account", tags=["account_management"])
51
 
52
  @app.get("/")
53
  def root():
app/core/config.py CHANGED
@@ -12,19 +12,35 @@ class Settings:
12
  CACHE_URI: str = os.getenv("CACHE_URI")
13
  CACHE_K: str = os.getenv("CACHE_K")
14
 
 
15
  SECRET_KEY: str = os.getenv("SECRET_KEY", "B00Kmyservice@7")
16
  ALGORITHM: str = os.getenv("ALGORITHM", "HS256")
17
 
 
18
  TWILIO_ACCOUNT_SID: str = os.getenv("TWILIO_ACCOUNT_SID")
19
  TWILIO_AUTH_TOKEN: str = os.getenv("TWILIO_AUTH_TOKEN")
20
  TWILIO_SMS_FROM: str = os.getenv("TWILIO_SMS_FROM")
21
 
 
22
  SMTP_HOST: str = os.getenv("SMTP_HOST")
23
  SMTP_PORT: int = int(os.getenv("SMTP_PORT", "587"))
24
  SMTP_USER: str = os.getenv("SMTP_USER")
25
  SMTP_PASS: str = os.getenv("SMTP_PASS")
26
  SMTP_FROM: str = os.getenv("SMTP_FROM")
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  def __post_init__(self):
29
  if not self.MONGO_URI or not self.DB_NAME:
30
  raise ValueError("MongoDB URI or DB_NAME not configured.")
 
12
  CACHE_URI: str = os.getenv("CACHE_URI")
13
  CACHE_K: str = os.getenv("CACHE_K")
14
 
15
+ # JWT
16
  SECRET_KEY: str = os.getenv("SECRET_KEY", "B00Kmyservice@7")
17
  ALGORITHM: str = os.getenv("ALGORITHM", "HS256")
18
 
19
+ # Twilio SMS
20
  TWILIO_ACCOUNT_SID: str = os.getenv("TWILIO_ACCOUNT_SID")
21
  TWILIO_AUTH_TOKEN: str = os.getenv("TWILIO_AUTH_TOKEN")
22
  TWILIO_SMS_FROM: str = os.getenv("TWILIO_SMS_FROM")
23
 
24
+ # SMTP Email
25
  SMTP_HOST: str = os.getenv("SMTP_HOST")
26
  SMTP_PORT: int = int(os.getenv("SMTP_PORT", "587"))
27
  SMTP_USER: str = os.getenv("SMTP_USER")
28
  SMTP_PASS: str = os.getenv("SMTP_PASS")
29
  SMTP_FROM: str = os.getenv("SMTP_FROM")
30
 
31
+ # OAuth Providers
32
+ GOOGLE_CLIENT_ID: str = os.getenv("GOOGLE_CLIENT_ID")
33
+ APPLE_AUDIENCE: str = os.getenv("APPLE_AUDIENCE")
34
+ FACEBOOK_APP_ID: str = os.getenv("FACEBOOK_APP_ID")
35
+ FACEBOOK_APP_SECRET: str = os.getenv("FACEBOOK_APP_SECRET")
36
+
37
+ # Security Settings
38
+ MAX_LOGIN_ATTEMPTS: int = int(os.getenv("MAX_LOGIN_ATTEMPTS", "5"))
39
+ ACCOUNT_LOCK_DURATION: int = int(os.getenv("ACCOUNT_LOCK_DURATION", "900")) # 15 minutes
40
+ OTP_VALIDITY_MINUTES: int = int(os.getenv("OTP_VALIDITY_MINUTES", "5"))
41
+ IP_RATE_LIMIT_MAX: int = int(os.getenv("IP_RATE_LIMIT_MAX", "10"))
42
+ IP_RATE_LIMIT_WINDOW: int = int(os.getenv("IP_RATE_LIMIT_WINDOW", "3600")) # 1 hour
43
+
44
  def __post_init__(self):
45
  if not self.MONGO_URI or not self.DB_NAME:
46
  raise ValueError("MongoDB URI or DB_NAME not configured.")
app/middleware/rate_limiter.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import Request, HTTPException
2
+ from starlette.middleware.base import BaseHTTPMiddleware
3
+ import time
4
+ from collections import defaultdict, deque
5
+
6
+ class RateLimitMiddleware(BaseHTTPMiddleware):
7
+ def __init__(self, app, calls: int = 100, period: int = 60):
8
+ super().__init__(app)
9
+ self.calls = calls
10
+ self.period = period
11
+ self.clients = defaultdict(deque)
12
+
13
+ async def dispatch(self, request: Request, call_next):
14
+ client_ip = request.client.host
15
+ now = time.time()
16
+
17
+ # Clean old requests
18
+ while self.clients[client_ip] and self.clients[client_ip][0] <= now - self.period:
19
+ self.clients[client_ip].popleft()
20
+
21
+ # Check rate limit
22
+ if len(self.clients[client_ip]) >= self.calls:
23
+ raise HTTPException(status_code=429, detail="Rate limit exceeded")
24
+
25
+ self.clients[client_ip].append(now)
26
+ response = await call_next(request)
27
+ return response
app/middleware/security_middleware.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import Request, Response
2
+ from starlette.middleware.base import BaseHTTPMiddleware
3
+ from datetime import datetime
4
+ import json
5
+ import logging
6
+ from typing import Dict, Any
7
+ from app.core.nosql_client import db
8
+
9
+ # Configure logging
10
+ logging.basicConfig(level=logging.INFO)
11
+ logger = logging.getLogger(__name__)
12
+
13
+ class SecurityMiddleware(BaseHTTPMiddleware):
14
+ """
15
+ Enhanced security middleware for request logging, device tracking, and security monitoring
16
+ """
17
+
18
+ def __init__(self, app):
19
+ super().__init__(app)
20
+ self.security_collection = db.security_logs
21
+ self.device_collection = db.device_tracking
22
+
23
+ def get_client_ip(self, request: Request) -> str:
24
+ """Extract client IP from request headers"""
25
+ # Check for forwarded headers first (for proxy/load balancer scenarios)
26
+ forwarded_for = request.headers.get("X-Forwarded-For")
27
+ if forwarded_for:
28
+ return forwarded_for.split(",")[0].strip()
29
+
30
+ real_ip = request.headers.get("X-Real-IP")
31
+ if real_ip:
32
+ return real_ip
33
+
34
+ # Fallback to direct client IP
35
+ return request.client.host if request.client else "unknown"
36
+
37
+ def extract_device_info(self, request: Request) -> Dict[str, Any]:
38
+ """Extract device and browser information from request headers"""
39
+ user_agent = request.headers.get("User-Agent", "")
40
+ accept_language = request.headers.get("Accept-Language", "")
41
+ accept_encoding = request.headers.get("Accept-Encoding", "")
42
+
43
+ return {
44
+ "user_agent": user_agent,
45
+ "accept_language": accept_language,
46
+ "accept_encoding": accept_encoding,
47
+ "platform": self._parse_platform(user_agent),
48
+ "browser": self._parse_browser(user_agent)
49
+ }
50
+
51
+ def _parse_platform(self, user_agent: str) -> str:
52
+ """Parse platform from user agent string"""
53
+ user_agent_lower = user_agent.lower()
54
+
55
+ if "windows" in user_agent_lower:
56
+ return "Windows"
57
+ elif "macintosh" in user_agent_lower or "mac os" in user_agent_lower:
58
+ return "macOS"
59
+ elif "linux" in user_agent_lower:
60
+ return "Linux"
61
+ elif "android" in user_agent_lower:
62
+ return "Android"
63
+ elif "iphone" in user_agent_lower or "ipad" in user_agent_lower:
64
+ return "iOS"
65
+ else:
66
+ return "Unknown"
67
+
68
+ def _parse_browser(self, user_agent: str) -> str:
69
+ """Parse browser from user agent string"""
70
+ user_agent_lower = user_agent.lower()
71
+
72
+ if "chrome" in user_agent_lower and "edg" not in user_agent_lower:
73
+ return "Chrome"
74
+ elif "firefox" in user_agent_lower:
75
+ return "Firefox"
76
+ elif "safari" in user_agent_lower and "chrome" not in user_agent_lower:
77
+ return "Safari"
78
+ elif "edg" in user_agent_lower:
79
+ return "Edge"
80
+ elif "opera" in user_agent_lower:
81
+ return "Opera"
82
+ else:
83
+ return "Unknown"
84
+
85
+ def is_sensitive_endpoint(self, path: str) -> bool:
86
+ """Check if the endpoint is security-sensitive and should be logged"""
87
+ sensitive_paths = [
88
+ "/auth/",
89
+ "/login",
90
+ "/register",
91
+ "/otp",
92
+ "/oauth",
93
+ "/profile",
94
+ "/account",
95
+ "/security"
96
+ ]
97
+
98
+ return any(sensitive_path in path for sensitive_path in sensitive_paths)
99
+
100
+ async def log_security_event(self, request: Request, response: Response,
101
+ processing_time: float, client_ip: str,
102
+ device_info: Dict[str, Any]):
103
+ """Log security-relevant events to database"""
104
+ try:
105
+ # Only log sensitive endpoints or failed requests
106
+ if not (self.is_sensitive_endpoint(str(request.url.path)) or response.status_code >= 400):
107
+ return
108
+
109
+ log_entry = {
110
+ "timestamp": datetime.utcnow(),
111
+ "method": request.method,
112
+ "path": str(request.url.path),
113
+ "query_params": dict(request.query_params),
114
+ "client_ip": client_ip,
115
+ "status_code": response.status_code,
116
+ "processing_time_ms": round(processing_time * 1000, 2),
117
+ "device_info": device_info,
118
+ "headers": {
119
+ "user_agent": request.headers.get("User-Agent", ""),
120
+ "referer": request.headers.get("Referer", ""),
121
+ "content_type": request.headers.get("Content-Type", "")
122
+ },
123
+ "is_suspicious": self._detect_suspicious_activity(request, response, client_ip)
124
+ }
125
+
126
+ # Add user ID if available from JWT token
127
+ auth_header = request.headers.get("Authorization")
128
+ if auth_header and auth_header.startswith("Bearer "):
129
+ try:
130
+ from app.utils.jwt import decode_token
131
+ token = auth_header.split(" ")[1]
132
+ payload = decode_token(token)
133
+ log_entry["user_id"] = payload.get("user_id")
134
+ except Exception:
135
+ pass # Token might be invalid or expired
136
+
137
+ await self.security_collection.insert_one(log_entry)
138
+
139
+ except Exception as e:
140
+ logger.error(f"Failed to log security event: {str(e)}")
141
+
142
+ async def track_device(self, client_ip: str, device_info: Dict[str, Any],
143
+ user_id: str = None):
144
+ """Track device information for security monitoring"""
145
+ try:
146
+ device_fingerprint = f"{client_ip}_{device_info.get('user_agent', '')[:100]}"
147
+
148
+ device_entry = {
149
+ "device_fingerprint": device_fingerprint,
150
+ "client_ip": client_ip,
151
+ "device_info": device_info,
152
+ "first_seen": datetime.utcnow(),
153
+ "last_seen": datetime.utcnow(),
154
+ "user_id": user_id,
155
+ "access_count": 1,
156
+ "is_trusted": False
157
+ }
158
+
159
+ # Update or insert device tracking
160
+ await self.device_collection.update_one(
161
+ {"device_fingerprint": device_fingerprint},
162
+ {
163
+ "$set": {
164
+ "last_seen": datetime.utcnow(),
165
+ "device_info": device_info
166
+ },
167
+ "$inc": {"access_count": 1},
168
+ "$setOnInsert": {
169
+ "device_fingerprint": device_fingerprint,
170
+ "client_ip": client_ip,
171
+ "first_seen": datetime.utcnow(),
172
+ "user_id": user_id,
173
+ "is_trusted": False
174
+ }
175
+ },
176
+ upsert=True
177
+ )
178
+
179
+ except Exception as e:
180
+ logger.error(f"Failed to track device: {str(e)}")
181
+
182
+ def _detect_suspicious_activity(self, request: Request, response: Response,
183
+ client_ip: str) -> bool:
184
+ """Detect potentially suspicious activity patterns"""
185
+ suspicious_indicators = []
186
+
187
+ # Check for multiple failed login attempts
188
+ if response.status_code == 401 and "login" in str(request.url.path):
189
+ suspicious_indicators.append("failed_login")
190
+
191
+ # Check for unusual user agent patterns
192
+ user_agent = request.headers.get("User-Agent", "")
193
+ if not user_agent or len(user_agent) < 10:
194
+ suspicious_indicators.append("suspicious_user_agent")
195
+
196
+ # Check for rapid requests (basic detection)
197
+ if hasattr(request.state, "request_count") and request.state.request_count > 10:
198
+ suspicious_indicators.append("rapid_requests")
199
+
200
+ # Check for access to sensitive endpoints without proper authentication
201
+ if (self.is_sensitive_endpoint(str(request.url.path)) and
202
+ response.status_code == 403 and
203
+ not request.headers.get("Authorization")):
204
+ suspicious_indicators.append("unauthorized_sensitive_access")
205
+
206
+ return len(suspicious_indicators) > 0
207
+
208
+ async def dispatch(self, request: Request, call_next):
209
+ """Main middleware dispatch method"""
210
+ start_time = datetime.utcnow()
211
+
212
+ # Extract client information
213
+ client_ip = self.get_client_ip(request)
214
+ device_info = self.extract_device_info(request)
215
+
216
+ # Process the request
217
+ response = await call_next(request)
218
+
219
+ # Calculate processing time
220
+ end_time = datetime.utcnow()
221
+ processing_time = (end_time - start_time).total_seconds()
222
+
223
+ # Log security events asynchronously
224
+ try:
225
+ await self.log_security_event(request, response, processing_time,
226
+ client_ip, device_info)
227
+ await self.track_device(client_ip, device_info)
228
+ except Exception as e:
229
+ logger.error(f"Security middleware error: {str(e)}")
230
+
231
+ # Add security headers to response
232
+ response.headers["X-Content-Type-Options"] = "nosniff"
233
+ response.headers["X-Frame-Options"] = "DENY"
234
+ response.headers["X-XSS-Protection"] = "1; mode=block"
235
+ response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
236
+
237
+ return response
app/models/otp_model.py CHANGED
@@ -11,16 +11,26 @@ class BookMyServiceOTPModel:
11
  OTP_TTL = 300 # 5 minutes
12
  RATE_LIMIT_MAX = 3
13
  RATE_LIMIT_WINDOW = 600 # 10 minutes
 
 
 
 
 
14
 
15
  @staticmethod
16
- async def store_otp(identifier: str, phone: str, otp: str, ttl: int = OTP_TTL):
17
- logger.info(f"Storing OTP for identifier: {identifier}")
18
 
19
  try:
20
  redis = await get_redis()
21
  logger.debug(f"Redis connection established for OTP storage")
22
 
23
- # Rate limit: max 3 OTPs per 10 minutes
 
 
 
 
 
24
  rate_key = f"otp_rate_limit:{identifier}"
25
  logger.debug(f"Checking rate limit with key: {rate_key}")
26
 
@@ -34,6 +44,17 @@ class BookMyServiceOTPModel:
34
  logger.warning(f"Rate limit exceeded for {identifier}: {attempts} attempts")
35
  raise HTTPException(status_code=429, detail="Too many OTP requests. Try again later.")
36
 
 
 
 
 
 
 
 
 
 
 
 
37
  # Store OTP
38
  otp_key = f"bms_otp:{identifier}"
39
  await redis.setex(otp_key, ttl, otp)
@@ -62,14 +83,19 @@ class BookMyServiceOTPModel:
62
  raise HTTPException(status_code=500, detail="SMS failed and no email fallback available.")
63
  '''
64
  @staticmethod
65
- async def verify_otp(identifier: str, otp: str):
66
- logger.info(f"Verifying OTP for identifier: {identifier}")
67
  logger.debug(f"Provided OTP: {otp}")
68
 
69
  try:
70
  redis = await get_redis()
71
  logger.debug("Redis connection established for OTP verification")
72
 
 
 
 
 
 
73
  key = f"bms_otp:{identifier}"
74
  logger.debug(f"Looking up OTP with key: {key}")
75
 
@@ -81,15 +107,24 @@ class BookMyServiceOTPModel:
81
  if stored == otp:
82
  logger.info(f"OTP verification successful for {identifier}")
83
  await redis.delete(key)
 
 
84
  logger.debug(f"OTP deleted from Redis after successful verification")
85
  return True
86
  else:
87
  logger.warning(f"OTP mismatch for {identifier}: provided='{otp}' vs stored='{stored}'")
 
 
88
  return False
89
  else:
90
  logger.warning(f"No OTP found in Redis for identifier: {identifier} with key: {key}")
 
 
91
  return False
92
 
 
 
 
93
  except Exception as e:
94
  logger.error(f"Error verifying OTP for {identifier}: {str(e)}", exc_info=True)
95
  return False
@@ -101,4 +136,95 @@ class BookMyServiceOTPModel:
101
  otp = await redis.get(key)
102
  if otp:
103
  return otp
104
- raise HTTPException(status_code=404, detail="OTP not found or expired")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  OTP_TTL = 300 # 5 minutes
12
  RATE_LIMIT_MAX = 3
13
  RATE_LIMIT_WINDOW = 600 # 10 minutes
14
+ IP_RATE_LIMIT_MAX = 10 # Max 10 OTPs per IP per hour
15
+ IP_RATE_LIMIT_WINDOW = 3600 # 1 hour
16
+ FAILED_ATTEMPTS_MAX = 5 # Max 5 failed attempts before lock
17
+ FAILED_ATTEMPTS_WINDOW = 3600 # 1 hour
18
+ ACCOUNT_LOCK_DURATION = 1800 # 30 minutes
19
 
20
  @staticmethod
21
+ async def store_otp(identifier: str, phone: str, otp: str, ttl: int = OTP_TTL, client_ip: str = None):
22
+ logger.info(f"Storing OTP for identifier: {identifier}, IP: {client_ip}")
23
 
24
  try:
25
  redis = await get_redis()
26
  logger.debug(f"Redis connection established for OTP storage")
27
 
28
+ # Check if account is locked
29
+ if await BookMyServiceOTPModel.is_account_locked(identifier):
30
+ logger.warning(f"Account locked for identifier: {identifier}")
31
+ raise HTTPException(status_code=423, detail="Account temporarily locked due to too many failed attempts")
32
+
33
+ # Rate limit: max 3 OTPs per identifier per 10 minutes
34
  rate_key = f"otp_rate_limit:{identifier}"
35
  logger.debug(f"Checking rate limit with key: {rate_key}")
36
 
 
44
  logger.warning(f"Rate limit exceeded for {identifier}: {attempts} attempts")
45
  raise HTTPException(status_code=429, detail="Too many OTP requests. Try again later.")
46
 
47
+ # IP-based rate limiting
48
+ if client_ip:
49
+ ip_rate_key = f"otp_ip_rate_limit:{client_ip}"
50
+ ip_attempts = await redis.incr(ip_rate_key)
51
+
52
+ if ip_attempts == 1:
53
+ await redis.expire(ip_rate_key, BookMyServiceOTPModel.IP_RATE_LIMIT_WINDOW)
54
+ elif ip_attempts > BookMyServiceOTPModel.IP_RATE_LIMIT_MAX:
55
+ logger.warning(f"IP rate limit exceeded for {client_ip}: {ip_attempts} attempts")
56
+ raise HTTPException(status_code=429, detail="Too many OTP requests from this IP address")
57
+
58
  # Store OTP
59
  otp_key = f"bms_otp:{identifier}"
60
  await redis.setex(otp_key, ttl, otp)
 
83
  raise HTTPException(status_code=500, detail="SMS failed and no email fallback available.")
84
  '''
85
  @staticmethod
86
+ async def verify_otp(identifier: str, otp: str, client_ip: str = None):
87
+ logger.info(f"Verifying OTP for identifier: {identifier}, IP: {client_ip}")
88
  logger.debug(f"Provided OTP: {otp}")
89
 
90
  try:
91
  redis = await get_redis()
92
  logger.debug("Redis connection established for OTP verification")
93
 
94
+ # Check if account is locked
95
+ if await BookMyServiceOTPModel.is_account_locked(identifier):
96
+ logger.warning(f"Account locked for identifier: {identifier}")
97
+ raise HTTPException(status_code=423, detail="Account temporarily locked due to too many failed attempts")
98
+
99
  key = f"bms_otp:{identifier}"
100
  logger.debug(f"Looking up OTP with key: {key}")
101
 
 
107
  if stored == otp:
108
  logger.info(f"OTP verification successful for {identifier}")
109
  await redis.delete(key)
110
+ # Clear failed attempts on successful verification
111
+ await BookMyServiceOTPModel.clear_failed_attempts(identifier)
112
  logger.debug(f"OTP deleted from Redis after successful verification")
113
  return True
114
  else:
115
  logger.warning(f"OTP mismatch for {identifier}: provided='{otp}' vs stored='{stored}'")
116
+ # Track failed attempt
117
+ await BookMyServiceOTPModel.track_failed_attempt(identifier, client_ip)
118
  return False
119
  else:
120
  logger.warning(f"No OTP found in Redis for identifier: {identifier} with key: {key}")
121
+ # Track failed attempt for expired/non-existent OTP
122
+ await BookMyServiceOTPModel.track_failed_attempt(identifier, client_ip)
123
  return False
124
 
125
+ except HTTPException as e:
126
+ logger.error(f"HTTP error verifying OTP for {identifier}: {e.status_code} - {e.detail}")
127
+ raise e
128
  except Exception as e:
129
  logger.error(f"Error verifying OTP for {identifier}: {str(e)}", exc_info=True)
130
  return False
 
136
  otp = await redis.get(key)
137
  if otp:
138
  return otp
139
+ raise HTTPException(status_code=404, detail="OTP not found or expired")
140
+
141
+ @staticmethod
142
+ async def track_failed_attempt(identifier: str, client_ip: str = None):
143
+ """Track failed OTP verification attempts"""
144
+ logger.info(f"Tracking failed attempt for identifier: {identifier}, IP: {client_ip}")
145
+
146
+ try:
147
+ redis = await get_redis()
148
+
149
+ # Track failed attempts for identifier
150
+ failed_key = f"failed_otp:{identifier}"
151
+ attempts = await redis.incr(failed_key)
152
+
153
+ if attempts == 1:
154
+ await redis.expire(failed_key, BookMyServiceOTPModel.FAILED_ATTEMPTS_WINDOW)
155
+
156
+ # Lock account if too many failed attempts
157
+ if attempts >= BookMyServiceOTPModel.FAILED_ATTEMPTS_MAX:
158
+ await BookMyServiceOTPModel.lock_account(identifier)
159
+ logger.warning(f"Account locked for {identifier} after {attempts} failed attempts")
160
+
161
+ # Track IP-based failed attempts
162
+ if client_ip:
163
+ ip_failed_key = f"failed_otp_ip:{client_ip}"
164
+ ip_attempts = await redis.incr(ip_failed_key)
165
+
166
+ if ip_attempts == 1:
167
+ await redis.expire(ip_failed_key, BookMyServiceOTPModel.FAILED_ATTEMPTS_WINDOW)
168
+
169
+ logger.debug(f"IP {client_ip} failed attempts: {ip_attempts}")
170
+
171
+ except Exception as e:
172
+ logger.error(f"Error tracking failed attempt for {identifier}: {str(e)}", exc_info=True)
173
+
174
+ @staticmethod
175
+ async def clear_failed_attempts(identifier: str):
176
+ """Clear failed attempts counter on successful verification"""
177
+ try:
178
+ redis = await get_redis()
179
+ failed_key = f"failed_otp:{identifier}"
180
+ await redis.delete(failed_key)
181
+ logger.debug(f"Cleared failed attempts for {identifier}")
182
+ except Exception as e:
183
+ logger.error(f"Error clearing failed attempts for {identifier}: {str(e)}", exc_info=True)
184
+
185
+ @staticmethod
186
+ async def lock_account(identifier: str):
187
+ """Lock account temporarily"""
188
+ try:
189
+ redis = await get_redis()
190
+ lock_key = f"account_locked:{identifier}"
191
+ await redis.setex(lock_key, BookMyServiceOTPModel.ACCOUNT_LOCK_DURATION, "locked")
192
+ logger.info(f"Account locked for {identifier} for {BookMyServiceOTPModel.ACCOUNT_LOCK_DURATION} seconds")
193
+ except Exception as e:
194
+ logger.error(f"Error locking account for {identifier}: {str(e)}", exc_info=True)
195
+
196
+ @staticmethod
197
+ async def is_account_locked(identifier: str) -> bool:
198
+ """Check if account is currently locked"""
199
+ try:
200
+ redis = await get_redis()
201
+ lock_key = f"account_locked:{identifier}"
202
+ locked = await redis.get(lock_key)
203
+ return locked is not None
204
+ except Exception as e:
205
+ logger.error(f"Error checking account lock for {identifier}: {str(e)}", exc_info=True)
206
+ return False
207
+
208
+ @staticmethod
209
+ async def get_rate_limit_count(rate_key: str) -> int:
210
+ """Get current rate limit count for a key"""
211
+ try:
212
+ redis = await get_redis()
213
+ count = await redis.get(rate_key)
214
+ return int(count) if count else 0
215
+ except Exception as e:
216
+ logger.error(f"Error getting rate limit count for {rate_key}: {str(e)}", exc_info=True)
217
+ return 0
218
+
219
+ @staticmethod
220
+ async def increment_rate_limit(rate_key: str, window: int) -> int:
221
+ """Increment rate limit counter with expiry"""
222
+ try:
223
+ redis = await get_redis()
224
+ count = await redis.incr(rate_key)
225
+ if count == 1:
226
+ await redis.expire(rate_key, window)
227
+ return count
228
+ except Exception as e:
229
+ logger.error(f"Error incrementing rate limit for {rate_key}: {str(e)}", exc_info=True)
230
+ return 0
app/models/social_account_model.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import HTTPException
2
+ from app.core.nosql_client import db
3
+ from datetime import datetime
4
+ from typing import Optional, List, Dict, Any
5
+ import logging
6
+
7
+ logger = logging.getLogger("social_account_model")
8
+
9
+ class SocialAccountModel:
10
+ """Model for managing social login accounts and linking"""
11
+
12
+ collection = db["social_accounts"]
13
+
14
+ @staticmethod
15
+ async def create_social_account(user_id: str, provider: str, provider_user_id: str, user_info: Dict[str, Any]) -> str:
16
+ """Create a new social account record"""
17
+ try:
18
+ social_account = {
19
+ "user_id": user_id,
20
+ "provider": provider,
21
+ "provider_user_id": provider_user_id,
22
+ "email": user_info.get("email"),
23
+ "name": user_info.get("name"),
24
+ "picture": user_info.get("picture"),
25
+ "profile_data": user_info,
26
+ "created_at": datetime.utcnow(),
27
+ "updated_at": datetime.utcnow(),
28
+ "is_active": True,
29
+ "last_login": datetime.utcnow()
30
+ }
31
+
32
+ result = await SocialAccountModel.collection.insert_one(social_account)
33
+ logger.info(f"Created social account for user {user_id} with provider {provider}")
34
+ return str(result.inserted_id)
35
+
36
+ except Exception as e:
37
+ logger.error(f"Error creating social account: {str(e)}", exc_info=True)
38
+ raise HTTPException(status_code=500, detail="Failed to create social account")
39
+
40
+ @staticmethod
41
+ async def find_by_provider_and_user_id(provider: str, provider_user_id: str) -> Optional[Dict[str, Any]]:
42
+ """Find social account by provider and provider user ID"""
43
+ try:
44
+ account = await SocialAccountModel.collection.find_one({
45
+ "provider": provider,
46
+ "provider_user_id": provider_user_id,
47
+ "is_active": True
48
+ })
49
+ return account
50
+ except Exception as e:
51
+ logger.error(f"Error finding social account: {str(e)}", exc_info=True)
52
+ return None
53
+
54
+ @staticmethod
55
+ async def find_by_user_id(user_id: str) -> List[Dict[str, Any]]:
56
+ """Find all social accounts for a user"""
57
+ try:
58
+ cursor = SocialAccountModel.collection.find({
59
+ "user_id": user_id,
60
+ "is_active": True
61
+ })
62
+ accounts = await cursor.to_list(length=None)
63
+ return accounts
64
+ except Exception as e:
65
+ logger.error(f"Error finding social accounts for user {user_id}: {str(e)}", exc_info=True)
66
+ return []
67
+
68
+ @staticmethod
69
+ async def update_social_account(provider: str, provider_user_id: str, user_info: Dict[str, Any]) -> bool:
70
+ """Update social account with latest user info"""
71
+ try:
72
+ update_data = {
73
+ "email": user_info.get("email"),
74
+ "name": user_info.get("name"),
75
+ "picture": user_info.get("picture"),
76
+ "profile_data": user_info,
77
+ "updated_at": datetime.utcnow(),
78
+ "last_login": datetime.utcnow()
79
+ }
80
+
81
+ result = await SocialAccountModel.collection.update_one(
82
+ {
83
+ "provider": provider,
84
+ "provider_user_id": provider_user_id,
85
+ "is_active": True
86
+ },
87
+ {"$set": update_data}
88
+ )
89
+
90
+ return result.modified_count > 0
91
+
92
+ except Exception as e:
93
+ logger.error(f"Error updating social account: {str(e)}", exc_info=True)
94
+ return False
95
+
96
+ @staticmethod
97
+ async def link_social_account(user_id: str, provider: str, provider_user_id: str, user_info: Dict[str, Any]) -> bool:
98
+ """Link a social account to an existing user"""
99
+ try:
100
+ # Check if this social account is already linked to another user
101
+ existing_account = await SocialAccountModel.find_by_provider_and_user_id(provider, provider_user_id)
102
+
103
+ if existing_account and existing_account["user_id"] != user_id:
104
+ logger.warning(f"Social account {provider}:{provider_user_id} already linked to user {existing_account['user_id']}")
105
+ raise HTTPException(
106
+ status_code=409,
107
+ detail=f"This {provider} account is already linked to another user"
108
+ )
109
+
110
+ if existing_account and existing_account["user_id"] == user_id:
111
+ # Update existing account
112
+ await SocialAccountModel.update_social_account(provider, provider_user_id, user_info)
113
+ return True
114
+
115
+ # Create new social account link
116
+ await SocialAccountModel.create_social_account(user_id, provider, provider_user_id, user_info)
117
+ return True
118
+
119
+ except HTTPException:
120
+ raise
121
+ except Exception as e:
122
+ logger.error(f"Error linking social account: {str(e)}", exc_info=True)
123
+ raise HTTPException(status_code=500, detail="Failed to link social account")
124
+
125
+ @staticmethod
126
+ async def unlink_social_account(user_id: str, provider: str) -> bool:
127
+ """Unlink a social account from a user"""
128
+ try:
129
+ result = await SocialAccountModel.collection.update_one(
130
+ {
131
+ "user_id": user_id,
132
+ "provider": provider,
133
+ "is_active": True
134
+ },
135
+ {
136
+ "$set": {
137
+ "is_active": False,
138
+ "updated_at": datetime.utcnow()
139
+ }
140
+ }
141
+ )
142
+
143
+ if result.modified_count > 0:
144
+ logger.info(f"Unlinked {provider} account for user {user_id}")
145
+ return True
146
+ else:
147
+ logger.warning(f"No active {provider} account found for user {user_id}")
148
+ return False
149
+
150
+ except Exception as e:
151
+ logger.error(f"Error unlinking social account: {str(e)}", exc_info=True)
152
+ return False
153
+
154
+ @staticmethod
155
+ async def get_profile_picture(user_id: str, preferred_provider: str = None) -> Optional[str]:
156
+ """Get user's profile picture from social accounts"""
157
+ try:
158
+ query = {"user_id": user_id, "is_active": True}
159
+
160
+ # If preferred provider specified, try that first
161
+ if preferred_provider:
162
+ account = await SocialAccountModel.collection.find_one({
163
+ **query,
164
+ "provider": preferred_provider,
165
+ "picture": {"$exists": True, "$ne": None}
166
+ })
167
+ if account and account.get("picture"):
168
+ return account["picture"]
169
+
170
+ # Otherwise, get any account with a profile picture
171
+ account = await SocialAccountModel.collection.find_one({
172
+ **query,
173
+ "picture": {"$exists": True, "$ne": None}
174
+ })
175
+
176
+ return account.get("picture") if account else None
177
+
178
+ except Exception as e:
179
+ logger.error(f"Error getting profile picture for user {user_id}: {str(e)}", exc_info=True)
180
+ return None
181
+
182
+ @staticmethod
183
+ async def get_social_account_summary(user_id: str) -> Dict[str, Any]:
184
+ """Get summary of all linked social accounts for a user"""
185
+ try:
186
+ accounts = await SocialAccountModel.find_by_user_id(user_id)
187
+
188
+ summary = {
189
+ "linked_accounts": [],
190
+ "total_accounts": len(accounts),
191
+ "profile_picture": None
192
+ }
193
+
194
+ for account in accounts:
195
+ summary["linked_accounts"].append({
196
+ "provider": account["provider"],
197
+ "email": account.get("email"),
198
+ "name": account.get("name"),
199
+ "linked_at": account["created_at"],
200
+ "last_login": account.get("last_login")
201
+ })
202
+
203
+ # Set profile picture if available
204
+ if not summary["profile_picture"] and account.get("picture"):
205
+ summary["profile_picture"] = account["picture"]
206
+
207
+ return summary
208
+
209
+ except Exception as e:
210
+ logger.error(f"Error getting social account summary for user {user_id}: {str(e)}", exc_info=True)
211
+ return {"linked_accounts": [], "total_accounts": 0, "profile_picture": None}
212
+
213
+ @staticmethod
214
+ async def merge_social_accounts(primary_user_id: str, secondary_user_id: str) -> bool:
215
+ """Merge social accounts from secondary user to primary user"""
216
+ try:
217
+ # Get all social accounts from secondary user
218
+ secondary_accounts = await SocialAccountModel.find_by_user_id(secondary_user_id)
219
+
220
+ for account in secondary_accounts:
221
+ # Check if primary user already has this provider linked
222
+ existing = await SocialAccountModel.collection.find_one({
223
+ "user_id": primary_user_id,
224
+ "provider": account["provider"],
225
+ "is_active": True
226
+ })
227
+
228
+ if not existing:
229
+ # Transfer the account to primary user
230
+ await SocialAccountModel.collection.update_one(
231
+ {"_id": account["_id"]},
232
+ {
233
+ "$set": {
234
+ "user_id": primary_user_id,
235
+ "updated_at": datetime.utcnow()
236
+ }
237
+ }
238
+ )
239
+ logger.info(f"Transferred {account['provider']} account from user {secondary_user_id} to {primary_user_id}")
240
+ else:
241
+ # Deactivate the secondary account
242
+ await SocialAccountModel.collection.update_one(
243
+ {"_id": account["_id"]},
244
+ {
245
+ "$set": {
246
+ "is_active": False,
247
+ "updated_at": datetime.utcnow()
248
+ }
249
+ }
250
+ )
251
+ logger.info(f"Deactivated duplicate {account['provider']} account for user {secondary_user_id}")
252
+
253
+ return True
254
+
255
+ except Exception as e:
256
+ logger.error(f"Error merging social accounts: {str(e)}", exc_info=True)
257
+ return False
app/models/social_security_model.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime, timedelta
2
+ import logging
3
+ from app.core.cache_client import get_redis
4
+ from fastapi import HTTPException
5
+
6
+ logger = logging.getLogger(__name__)
7
+
8
+ class SocialSecurityModel:
9
+ """Model for handling social login security features"""
10
+
11
+ # Rate limiting constants
12
+ OAUTH_RATE_LIMIT_MAX = 5 # Max OAuth attempts per IP per hour
13
+ OAUTH_RATE_LIMIT_WINDOW = 3600 # 1 hour in seconds
14
+
15
+ # Failed attempt tracking
16
+ OAUTH_FAILED_ATTEMPTS_MAX = 3 # Max failed OAuth attempts per IP
17
+ OAUTH_FAILED_ATTEMPTS_WINDOW = 1800 # 30 minutes
18
+ OAUTH_IP_LOCK_DURATION = 3600 # 1 hour lock for IP
19
+
20
+ @staticmethod
21
+ async def check_oauth_rate_limit(client_ip: str, provider: str) -> bool:
22
+ """Check if OAuth rate limit is exceeded for IP and provider"""
23
+ if not client_ip:
24
+ return True # Allow if no IP provided
25
+
26
+ try:
27
+ redis = await get_redis()
28
+ rate_key = f"oauth_rate:{client_ip}:{provider}"
29
+
30
+ current_count = await redis.get(rate_key)
31
+ if current_count and int(current_count) >= SocialSecurityModel.OAUTH_RATE_LIMIT_MAX:
32
+ logger.warning(f"OAuth rate limit exceeded for IP {client_ip} and provider {provider}")
33
+ return False
34
+
35
+ return True
36
+
37
+ except Exception as e:
38
+ logger.error(f"Error checking OAuth rate limit: {str(e)}", exc_info=True)
39
+ return True # Allow on error to avoid blocking legitimate users
40
+
41
+ @staticmethod
42
+ async def increment_oauth_rate_limit(client_ip: str, provider: str):
43
+ """Increment OAuth rate limit counter"""
44
+ if not client_ip:
45
+ return
46
+
47
+ try:
48
+ redis = await get_redis()
49
+ rate_key = f"oauth_rate:{client_ip}:{provider}"
50
+
51
+ count = await redis.incr(rate_key)
52
+ if count == 1:
53
+ await redis.expire(rate_key, SocialSecurityModel.OAUTH_RATE_LIMIT_WINDOW)
54
+
55
+ logger.debug(f"OAuth rate limit count for {client_ip}:{provider} = {count}")
56
+
57
+ except Exception as e:
58
+ logger.error(f"Error incrementing OAuth rate limit: {str(e)}", exc_info=True)
59
+
60
+ @staticmethod
61
+ async def track_oauth_failed_attempt(client_ip: str, provider: str):
62
+ """Track failed OAuth verification attempts"""
63
+ if not client_ip:
64
+ return
65
+
66
+ try:
67
+ redis = await get_redis()
68
+ failed_key = f"oauth_failed:{client_ip}:{provider}"
69
+
70
+ attempts = await redis.incr(failed_key)
71
+ if attempts == 1:
72
+ await redis.expire(failed_key, SocialSecurityModel.OAUTH_FAILED_ATTEMPTS_WINDOW)
73
+
74
+ # Lock IP if too many failed attempts
75
+ if attempts >= SocialSecurityModel.OAUTH_FAILED_ATTEMPTS_MAX:
76
+ await SocialSecurityModel.lock_oauth_ip(client_ip, provider)
77
+ logger.warning(f"IP {client_ip} locked for provider {provider} after {attempts} failed attempts")
78
+
79
+ logger.debug(f"OAuth failed attempts for {client_ip}:{provider} = {attempts}")
80
+
81
+ except Exception as e:
82
+ logger.error(f"Error tracking OAuth failed attempt: {str(e)}", exc_info=True)
83
+
84
+ @staticmethod
85
+ async def lock_oauth_ip(client_ip: str, provider: str):
86
+ """Lock IP for OAuth attempts on specific provider"""
87
+ try:
88
+ redis = await get_redis()
89
+ lock_key = f"oauth_ip_locked:{client_ip}:{provider}"
90
+ await redis.setex(lock_key, SocialSecurityModel.OAUTH_IP_LOCK_DURATION, "locked")
91
+ logger.info(f"IP {client_ip} locked for OAuth provider {provider}")
92
+ except Exception as e:
93
+ logger.error(f"Error locking OAuth IP: {str(e)}", exc_info=True)
94
+
95
+ @staticmethod
96
+ async def is_oauth_ip_locked(client_ip: str, provider: str) -> bool:
97
+ """Check if IP is locked for OAuth attempts on specific provider"""
98
+ if not client_ip:
99
+ return False
100
+
101
+ try:
102
+ redis = await get_redis()
103
+ lock_key = f"oauth_ip_locked:{client_ip}:{provider}"
104
+ locked = await redis.get(lock_key)
105
+ return locked is not None
106
+ except Exception as e:
107
+ logger.error(f"Error checking OAuth IP lock: {str(e)}", exc_info=True)
108
+ return False
109
+
110
+ @staticmethod
111
+ async def clear_oauth_failed_attempts(client_ip: str, provider: str):
112
+ """Clear failed OAuth attempts on successful verification"""
113
+ if not client_ip:
114
+ return
115
+
116
+ try:
117
+ redis = await get_redis()
118
+ failed_key = f"oauth_failed:{client_ip}:{provider}"
119
+ await redis.delete(failed_key)
120
+ logger.debug(f"Cleared OAuth failed attempts for {client_ip}:{provider}")
121
+ except Exception as e:
122
+ logger.error(f"Error clearing OAuth failed attempts: {str(e)}", exc_info=True)
123
+
124
+ @staticmethod
125
+ async def validate_oauth_token_format(token: str, provider: str) -> bool:
126
+ """Basic validation of OAuth token format"""
127
+ if not token or not isinstance(token, str):
128
+ return False
129
+
130
+ # Basic length and format checks
131
+ if provider == "google":
132
+ # Google ID tokens are typically JWT format
133
+ return len(token) > 100 and token.count('.') == 2
134
+ elif provider == "apple":
135
+ # Apple ID tokens are also JWT format
136
+ return len(token) > 100 and token.count('.') == 2
137
+ elif provider == "facebook":
138
+ # Facebook access tokens are typically shorter
139
+ return len(token) > 20 and len(token) < 500
140
+
141
+ return True # Allow unknown providers
142
+
143
+ @staticmethod
144
+ async def log_oauth_attempt(client_ip: str, provider: str, success: bool, user_id: str = None):
145
+ """Log OAuth authentication attempts for security monitoring"""
146
+ try:
147
+ redis = await get_redis()
148
+ log_key = f"oauth_log:{datetime.utcnow().strftime('%Y-%m-%d')}"
149
+
150
+ log_entry = {
151
+ "timestamp": datetime.utcnow().isoformat(),
152
+ "ip": client_ip,
153
+ "provider": provider,
154
+ "success": success,
155
+ "user_id": user_id
156
+ }
157
+
158
+ # Store as JSON string in Redis list
159
+ import json
160
+ await redis.lpush(log_key, json.dumps(log_entry))
161
+
162
+ # Keep only last 1000 entries per day
163
+ await redis.ltrim(log_key, 0, 999)
164
+
165
+ # Set expiry for 30 days
166
+ await redis.expire(log_key, 30 * 24 * 3600)
167
+
168
+ logger.info(f"OAuth attempt logged: {provider} from {client_ip} - {'success' if success else 'failed'}")
169
+
170
+ except Exception as e:
171
+ logger.error(f"Error logging OAuth attempt: {str(e)}", exc_info=True)
app/routers/account_router.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, Depends, HTTPException, Request, Query
2
+ from fastapi.security import HTTPBearer
3
+ from typing import List, Optional
4
+ from datetime import datetime, timedelta
5
+ import logging
6
+
7
+ from app.schemas.user_schema import (
8
+ LinkSocialAccountRequest, UnlinkSocialAccountRequest,
9
+ SocialAccountSummary, LoginHistoryResponse, SecuritySettingsResponse,
10
+ TokenResponse
11
+ )
12
+ from app.services.account_service import AccountService
13
+ from app.utils.jwt import decode_token
14
+
15
+ # Configure logging
16
+ logger = logging.getLogger(__name__)
17
+
18
+ router = APIRouter()
19
+ security = HTTPBearer()
20
+
21
+ def get_current_user(token: str = Depends(security)):
22
+ """Extract user ID from JWT token"""
23
+ try:
24
+ payload = decode_token(token.credentials)
25
+ user_id = payload.get("user_id")
26
+ if not user_id:
27
+ raise HTTPException(status_code=401, detail="Invalid token")
28
+ return user_id
29
+ except Exception as e:
30
+ logger.error(f"Token validation error: {str(e)}")
31
+ raise HTTPException(status_code=401, detail="Invalid or expired token")
32
+
33
+ def get_client_ip(request: Request) -> str:
34
+ """Extract client IP from request"""
35
+ forwarded_for = request.headers.get("X-Forwarded-For")
36
+ if forwarded_for:
37
+ return forwarded_for.split(",")[0].strip()
38
+
39
+ real_ip = request.headers.get("X-Real-IP")
40
+ if real_ip:
41
+ return real_ip
42
+
43
+ return request.client.host if request.client else "unknown"
44
+
45
+ @router.get("/social-accounts", response_model=SocialAccountSummary)
46
+ async def get_social_accounts(user_id: str = Depends(get_current_user)):
47
+ """Get all linked social accounts for the current user"""
48
+ try:
49
+ account_service = AccountService()
50
+ summary = await account_service.get_social_account_summary(user_id)
51
+ return summary
52
+ except Exception as e:
53
+ logger.error(f"Error fetching social accounts for user {user_id}: {str(e)}")
54
+ raise HTTPException(status_code=500, detail="Failed to fetch social accounts")
55
+
56
+ @router.post("/link-social-account", response_model=dict)
57
+ async def link_social_account(
58
+ request: LinkSocialAccountRequest,
59
+ req: Request,
60
+ user_id: str = Depends(get_current_user)
61
+ ):
62
+ """Link a new social account to the current user"""
63
+ try:
64
+ client_ip = get_client_ip(req)
65
+ account_service = AccountService()
66
+
67
+ result = await account_service.link_social_account(
68
+ user_id=user_id,
69
+ provider=request.provider,
70
+ token=request.token,
71
+ client_ip=client_ip
72
+ )
73
+
74
+ return {"message": f"Successfully linked {request.provider} account", "result": result}
75
+ except ValueError as e:
76
+ logger.warning(f"Invalid link request for user {user_id}: {str(e)}")
77
+ raise HTTPException(status_code=400, detail=str(e))
78
+ except Exception as e:
79
+ logger.error(f"Error linking social account for user {user_id}: {str(e)}")
80
+ raise HTTPException(status_code=500, detail="Failed to link social account")
81
+
82
+ @router.delete("/unlink-social-account", response_model=dict)
83
+ async def unlink_social_account(
84
+ request: UnlinkSocialAccountRequest,
85
+ user_id: str = Depends(get_current_user)
86
+ ):
87
+ """Unlink a social account from the current user"""
88
+ try:
89
+ account_service = AccountService()
90
+
91
+ result = await account_service.unlink_social_account(
92
+ user_id=user_id,
93
+ provider=request.provider
94
+ )
95
+
96
+ return {"message": f"Successfully unlinked {request.provider} account", "result": result}
97
+ except ValueError as e:
98
+ logger.warning(f"Invalid unlink request for user {user_id}: {str(e)}")
99
+ raise HTTPException(status_code=400, detail=str(e))
100
+ except Exception as e:
101
+ logger.error(f"Error unlinking social account for user {user_id}: {str(e)}")
102
+ raise HTTPException(status_code=500, detail="Failed to unlink social account")
103
+
104
+ @router.get("/login-history", response_model=LoginHistoryResponse)
105
+ async def get_login_history(
106
+ page: int = Query(1, ge=1, description="Page number"),
107
+ per_page: int = Query(10, ge=1, le=50, description="Items per page"),
108
+ days: int = Query(30, ge=1, le=365, description="Number of days to look back"),
109
+ user_id: str = Depends(get_current_user)
110
+ ):
111
+ """Get login history for the current user"""
112
+ try:
113
+ account_service = AccountService()
114
+
115
+ history = await account_service.get_login_history(
116
+ user_id=user_id,
117
+ page=page,
118
+ per_page=per_page,
119
+ days=days
120
+ )
121
+
122
+ return history
123
+ except Exception as e:
124
+ logger.error(f"Error fetching login history for user {user_id}: {str(e)}")
125
+ raise HTTPException(status_code=500, detail="Failed to fetch login history")
126
+
127
+ @router.get("/security-settings", response_model=SecuritySettingsResponse)
128
+ async def get_security_settings(user_id: str = Depends(get_current_user)):
129
+ """Get security settings and status for the current user"""
130
+ try:
131
+ account_service = AccountService()
132
+
133
+ settings = await account_service.get_security_settings(user_id)
134
+
135
+ return settings
136
+ except Exception as e:
137
+ logger.error(f"Error fetching security settings for user {user_id}: {str(e)}")
138
+ raise HTTPException(status_code=500, detail="Failed to fetch security settings")
139
+
140
+ @router.post("/merge-accounts", response_model=dict)
141
+ async def merge_social_accounts(
142
+ target_user_id: str,
143
+ req: Request,
144
+ user_id: str = Depends(get_current_user)
145
+ ):
146
+ """Merge social accounts from another user (admin function or user-initiated)"""
147
+ try:
148
+ # For security, only allow users to merge their own accounts or implement admin check
149
+ if user_id != target_user_id:
150
+ # In a real implementation, you'd check if the current user is an admin
151
+ # or if they have proper authorization to merge accounts
152
+ raise HTTPException(status_code=403, detail="Insufficient permissions")
153
+
154
+ client_ip = get_client_ip(req)
155
+ account_service = AccountService()
156
+
157
+ result = await account_service.merge_social_accounts(
158
+ primary_user_id=user_id,
159
+ secondary_user_id=target_user_id,
160
+ client_ip=client_ip
161
+ )
162
+
163
+ return {"message": "Successfully merged social accounts", "result": result}
164
+ except ValueError as e:
165
+ logger.warning(f"Invalid merge request for user {user_id}: {str(e)}")
166
+ raise HTTPException(status_code=400, detail=str(e))
167
+ except Exception as e:
168
+ logger.error(f"Error merging social accounts for user {user_id}: {str(e)}")
169
+ raise HTTPException(status_code=500, detail="Failed to merge social accounts")
170
+
171
+ @router.delete("/revoke-all-sessions", response_model=dict)
172
+ async def revoke_all_sessions(
173
+ req: Request,
174
+ user_id: str = Depends(get_current_user)
175
+ ):
176
+ """Revoke all active sessions for security purposes"""
177
+ try:
178
+ client_ip = get_client_ip(req)
179
+ account_service = AccountService()
180
+
181
+ result = await account_service.revoke_all_sessions(user_id, client_ip)
182
+
183
+ return {"message": "All sessions have been revoked", "result": result}
184
+ except Exception as e:
185
+ logger.error(f"Error revoking sessions for user {user_id}: {str(e)}")
186
+ raise HTTPException(status_code=500, detail="Failed to revoke sessions")
187
+
188
+ @router.get("/trusted-devices", response_model=dict)
189
+ async def get_trusted_devices(user_id: str = Depends(get_current_user)):
190
+ """Get list of trusted devices for the current user"""
191
+ try:
192
+ account_service = AccountService()
193
+
194
+ devices = await account_service.get_trusted_devices(user_id)
195
+
196
+ return {"devices": devices}
197
+ except Exception as e:
198
+ logger.error(f"Error fetching trusted devices for user {user_id}: {str(e)}")
199
+ raise HTTPException(status_code=500, detail="Failed to fetch trusted devices")
200
+
201
+ @router.delete("/trusted-devices/{device_id}", response_model=dict)
202
+ async def remove_trusted_device(
203
+ device_id: str,
204
+ user_id: str = Depends(get_current_user)
205
+ ):
206
+ """Remove a trusted device"""
207
+ try:
208
+ account_service = AccountService()
209
+
210
+ result = await account_service.remove_trusted_device(user_id, device_id)
211
+
212
+ return {"message": "Trusted device removed successfully", "result": result}
213
+ except ValueError as e:
214
+ logger.warning(f"Invalid device removal request for user {user_id}: {str(e)}")
215
+ raise HTTPException(status_code=400, detail=str(e))
216
+ except Exception as e:
217
+ logger.error(f"Error removing trusted device for user {user_id}: {str(e)}")
218
+ raise HTTPException(status_code=500, detail="Failed to remove trusted device")
app/routers/user_router.py CHANGED
@@ -12,8 +12,10 @@ from app.schemas.user_schema import (
12
  )
13
  from app.services.user_service import UserService
14
  from app.utils.jwt import create_temp_token, decode_token
15
- from app.utils.social_utils import verify_google_token, verify_apple_token
16
  from app.utils.common_utils import validate_identifier
 
 
17
  import logging
18
 
19
  logger = logging.getLogger("user_router")
@@ -187,23 +189,84 @@ async def otp_login_handler(
187
 
188
  # 🌐 OAuth Login for Google / Apple
189
  @router.post("/oauth-login", response_model=TokenResponse)
190
- async def oauth_login_handler(payload: OAuthLoginRequest):
191
- if payload.provider == "google":
192
- user_info = await verify_google_token(payload.token)
193
- user_id = f"google_{user_info['id']}"
194
- elif payload.provider == "apple":
195
- user_info = await verify_apple_token(payload.token)
196
- user_id = f"apple_{user_info['id']}"
197
- else:
198
- raise HTTPException(status_code=400, detail="Unsupported OAuth provider")
199
-
200
- temp_token = create_temp_token({
201
- "sub": user_id,
202
- "type": "oauth_session",
203
- "verified": True
204
- }, expires_minutes=10)
205
-
206
- return {"access_token": temp_token, "token_type": "bearer"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
 
208
  # 👤 Final user registration after OTP or OAuth
209
  @router.post("/register", response_model=TokenResponse)
 
12
  )
13
  from app.services.user_service import UserService
14
  from app.utils.jwt import create_temp_token, decode_token
15
+ from app.utils.social_utils import verify_google_token, verify_apple_token, verify_facebook_token
16
  from app.utils.common_utils import validate_identifier
17
+ from app.models.social_security_model import SocialSecurityModel
18
+ from fastapi import Request
19
  import logging
20
 
21
  logger = logging.getLogger("user_router")
 
189
 
190
  # 🌐 OAuth Login for Google / Apple
191
  @router.post("/oauth-login", response_model=TokenResponse)
192
+ async def oauth_login_handler(payload: OAuthLoginRequest, request: Request):
193
+ from app.core.config import settings
194
+
195
+ # Get client IP
196
+ client_ip = request.client.host if request.client else None
197
+
198
+ # Check if IP is locked for this provider
199
+ if await SocialSecurityModel.is_oauth_ip_locked(client_ip, payload.provider):
200
+ await SocialSecurityModel.log_oauth_attempt(client_ip, payload.provider, False)
201
+ raise HTTPException(
202
+ status_code=429,
203
+ detail=f"Too many failed attempts. IP temporarily locked for {payload.provider} OAuth."
204
+ )
205
+
206
+ # Check rate limiting
207
+ if not await SocialSecurityModel.check_oauth_rate_limit(client_ip, payload.provider):
208
+ await SocialSecurityModel.log_oauth_attempt(client_ip, payload.provider, False)
209
+ raise HTTPException(
210
+ status_code=429,
211
+ detail=f"Rate limit exceeded for {payload.provider} OAuth. Please try again later."
212
+ )
213
+
214
+ # Validate token format
215
+ if not await SocialSecurityModel.validate_oauth_token_format(payload.token, payload.provider):
216
+ await SocialSecurityModel.track_oauth_failed_attempt(client_ip, payload.provider)
217
+ await SocialSecurityModel.log_oauth_attempt(client_ip, payload.provider, False)
218
+ raise HTTPException(status_code=400, detail="Invalid token format")
219
+
220
+ # Increment rate limit counter
221
+ await SocialSecurityModel.increment_oauth_rate_limit(client_ip, payload.provider)
222
+
223
+ try:
224
+ if payload.provider == "google":
225
+ if not settings.GOOGLE_CLIENT_ID:
226
+ raise HTTPException(status_code=500, detail="Google OAuth not configured")
227
+ user_info = await verify_google_token(payload.token, settings.GOOGLE_CLIENT_ID)
228
+ user_id = f"google_{user_info.get('sub', user_info.get('id'))}"
229
+
230
+ elif payload.provider == "apple":
231
+ if not settings.APPLE_AUDIENCE:
232
+ raise HTTPException(status_code=500, detail="Apple OAuth not configured")
233
+ user_info = await verify_apple_token(payload.token, settings.APPLE_AUDIENCE)
234
+ user_id = f"apple_{user_info.get('sub', user_info.get('id'))}"
235
+
236
+ elif payload.provider == "facebook":
237
+ if not settings.FACEBOOK_APP_ID or not settings.FACEBOOK_APP_SECRET:
238
+ raise HTTPException(status_code=500, detail="Facebook OAuth not configured")
239
+ user_info = await verify_facebook_token(payload.token, settings.FACEBOOK_APP_ID, settings.FACEBOOK_APP_SECRET)
240
+ user_id = f"facebook_{user_info.get('id')}"
241
+
242
+ else:
243
+ raise HTTPException(status_code=400, detail="Unsupported OAuth provider")
244
+
245
+ # Clear failed attempts on successful verification
246
+ await SocialSecurityModel.clear_oauth_failed_attempts(client_ip, payload.provider)
247
+
248
+ # Log successful attempt
249
+ await SocialSecurityModel.log_oauth_attempt(client_ip, payload.provider, True, user_id)
250
+
251
+ temp_token = create_temp_token({
252
+ "sub": user_id,
253
+ "type": "oauth_session",
254
+ "verified": True,
255
+ "provider": payload.provider,
256
+ "user_info": user_info
257
+ }, expires_minutes=10)
258
+
259
+ return {"access_token": temp_token, "token_type": "bearer"}
260
+
261
+ except HTTPException:
262
+ # Re-raise HTTP exceptions (configuration errors, etc.)
263
+ raise
264
+ except Exception as e:
265
+ # Track failed attempt for token verification failures
266
+ await SocialSecurityModel.track_oauth_failed_attempt(client_ip, payload.provider)
267
+ await SocialSecurityModel.log_oauth_attempt(client_ip, payload.provider, False)
268
+ logger.error(f"OAuth verification failed for {payload.provider}: {str(e)}", exc_info=True)
269
+ raise HTTPException(status_code=401, detail="OAuth token verification failed")
270
 
271
  # 👤 Final user registration after OTP or OAuth
272
  @router.post("/register", response_model=TokenResponse)
app/schemas/user_schema.py CHANGED
@@ -1,17 +1,18 @@
1
  from pydantic import BaseModel, EmailStr, validator
2
- from typing import Optional, Literal
 
3
  import re
4
 
5
  # Used for OTP-based or OAuth-based user registration
6
  class UserRegisterRequest(BaseModel):
7
  name: str
8
- email: Optional[EmailStr] = None
9
- phone: Optional[str] = None
10
- otpIdentifer: Optional[str] = None # email or phone
11
  otp: Optional[str] = None
12
  dob: Optional[str] = None # ISO format date string
13
  oauth_token: Optional[str] = None
14
- provider: Optional[Literal["google", "apple"]] = None
15
  mode: Literal["otp", "oauth"]
16
 
17
  @validator('phone')
@@ -108,18 +109,77 @@ class OTPVerifyRequest(BaseModel):
108
 
109
  # OAuth login using Google/Apple
110
  class OAuthLoginRequest(BaseModel):
111
- provider: Literal["google", "apple"]
112
  token: str
113
 
114
- # JWT Token response format
115
  class TokenResponse(BaseModel):
116
  access_token: str
117
  token_type: str = "bearer"
118
- name: str = None # Added for OTP login response
 
 
 
 
 
 
 
119
 
120
- # Optional: profile info response post-login
121
  class UserProfileResponse(BaseModel):
122
  user_id: str
123
- full_name: str
124
  email: Optional[EmailStr] = None
125
- phone: Optional[str] = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from pydantic import BaseModel, EmailStr, validator
2
+ from typing import Optional, Literal, List, Dict, Any
3
+ from datetime import datetime
4
  import re
5
 
6
  # Used for OTP-based or OAuth-based user registration
7
  class UserRegisterRequest(BaseModel):
8
  name: str
9
+ email: EmailStr # Mandatory for all registration modes
10
+ phone: str # Mandatory for all registration modes (always used as OTP identifier)
11
+ otpIdentifer: Optional[str] = None # Deprecated - phone is always the OTP identifier
12
  otp: Optional[str] = None
13
  dob: Optional[str] = None # ISO format date string
14
  oauth_token: Optional[str] = None
15
+ provider: Optional[Literal["google", "apple", "facebook"]] = None
16
  mode: Literal["otp", "oauth"]
17
 
18
  @validator('phone')
 
109
 
110
  # OAuth login using Google/Apple
111
  class OAuthLoginRequest(BaseModel):
112
+ provider: Literal["google", "apple", "facebook"]
113
  token: str
114
 
115
+ # JWT Token response format with enhanced security info
116
  class TokenResponse(BaseModel):
117
  access_token: str
118
  token_type: str = "bearer"
119
+ expires_in: Optional[int] = 28800 # 8 hours in seconds
120
+ user_id: Optional[str] = None
121
+ name: Optional[str] = None
122
+ email: Optional[str] = None
123
+ profile_picture: Optional[str] = None
124
+ auth_method: Optional[str] = None # "otp" or "oauth"
125
+ provider: Optional[str] = None # For OAuth logins
126
+ security_info: Optional[Dict[str, Any]] = None
127
 
128
+ # Enhanced user profile response with social accounts
129
  class UserProfileResponse(BaseModel):
130
  user_id: str
131
+ name: str
132
  email: Optional[EmailStr] = None
133
+ phone: Optional[str] = None
134
+ profile_picture: Optional[str] = None
135
+ auth_method: str
136
+ created_at: datetime
137
+ social_accounts: Optional[List[Dict[str, Any]]] = None
138
+ security_info: Optional[Dict[str, Any]] = None
139
+
140
+ # Social account information
141
+ class SocialAccountInfo(BaseModel):
142
+ provider: str
143
+ email: Optional[str] = None
144
+ name: Optional[str] = None
145
+ linked_at: datetime
146
+ last_login: Optional[datetime] = None
147
+
148
+ # Social account summary response
149
+ class SocialAccountSummary(BaseModel):
150
+ linked_accounts: List[SocialAccountInfo]
151
+ total_accounts: int
152
+ profile_picture: Optional[str] = None
153
+
154
+ # Account linking request
155
+ class LinkSocialAccountRequest(BaseModel):
156
+ provider: Literal["google", "apple", "facebook"]
157
+ token: str
158
+
159
+ # Account unlinking request
160
+ class UnlinkSocialAccountRequest(BaseModel):
161
+ provider: Literal["google", "apple", "facebook"]
162
+
163
+ # Login history entry
164
+ class LoginHistoryEntry(BaseModel):
165
+ timestamp: datetime
166
+ method: str # "otp" or "oauth"
167
+ provider: Optional[str] = None
168
+ ip_address: Optional[str] = None
169
+ success: bool
170
+ device_info: Optional[str] = None
171
+
172
+ # Login history response
173
+ class LoginHistoryResponse(BaseModel):
174
+ entries: List[LoginHistoryEntry]
175
+ total_entries: int
176
+ page: int
177
+ per_page: int
178
+
179
+ # Security settings response
180
+ class SecuritySettingsResponse(BaseModel):
181
+ two_factor_enabled: bool = False
182
+ linked_social_accounts: int
183
+ last_password_change: Optional[datetime] = None
184
+ recent_login_attempts: int
185
+ account_locked: bool = False
app/services/account_service.py ADDED
@@ -0,0 +1,396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime, timedelta
2
+ from typing import List, Dict, Any, Optional
3
+ import logging
4
+ from bson import ObjectId
5
+
6
+ from app.models.social_account_model import SocialAccountModel
7
+ from app.models.user_model import BookMyServiceUserModel
8
+ from app.schemas.user_schema import (
9
+ SocialAccountSummary, SocialAccountInfo, LoginHistoryResponse,
10
+ LoginHistoryEntry, SecuritySettingsResponse
11
+ )
12
+ from app.utils.social_utils import verify_google_token, verify_apple_token, verify_facebook_token
13
+ from app.core.nosql_client import db
14
+
15
+ # Configure logging
16
+ logger = logging.getLogger(__name__)
17
+
18
+ class AccountService:
19
+ """Service for managing user accounts, social accounts, and security settings"""
20
+
21
+ def __init__(self):
22
+ self.security_collection = db.get_collection("security_logs")
23
+ self.device_collection = db.get_collection("device_tracking")
24
+ self.session_collection = db.get_collection("user_sessions")
25
+
26
+ async def get_social_account_summary(self, user_id: str) -> SocialAccountSummary:
27
+ """Get summary of all linked social accounts for a user"""
28
+ try:
29
+ social_accounts = await SocialAccountModel.find_by_user_id(user_id)
30
+
31
+ linked_accounts = []
32
+ profile_picture = None
33
+
34
+ for account in social_accounts:
35
+ account_info = SocialAccountInfo(
36
+ provider=account["provider"],
37
+ email=account.get("email"),
38
+ name=account.get("name"),
39
+ linked_at=account["created_at"],
40
+ last_login=account.get("last_login")
41
+ )
42
+ linked_accounts.append(account_info)
43
+
44
+ # Use the first available profile picture
45
+ if not profile_picture and account.get("profile_picture"):
46
+ profile_picture = account["profile_picture"]
47
+
48
+ return SocialAccountSummary(
49
+ linked_accounts=linked_accounts,
50
+ total_accounts=len(linked_accounts),
51
+ profile_picture=profile_picture
52
+ )
53
+
54
+ except Exception as e:
55
+ logger.error(f"Error getting social account summary for user {user_id}: {str(e)}")
56
+ raise
57
+
58
+ async def link_social_account(self, user_id: str, provider: str, token: str, client_ip: str) -> Dict[str, Any]:
59
+ """Link a new social account to an existing user"""
60
+ try:
61
+ # Verify the token and get user info
62
+ user_info = await self._verify_social_token(provider, token)
63
+
64
+ # Check if this social account is already linked to another user
65
+ existing_account = await SocialAccountModel.find_by_provider_id(
66
+ provider, user_info["id"]
67
+ )
68
+
69
+ if existing_account and existing_account["user_id"] != user_id:
70
+ raise ValueError(f"This {provider} account is already linked to another user")
71
+
72
+ # Check if user already has this provider linked
73
+ user_provider_account = await SocialAccountModel.find_by_user_and_provider(
74
+ user_id, provider
75
+ )
76
+
77
+ if user_provider_account:
78
+ # Update existing account
79
+ await SocialAccountModel.update_social_account(
80
+ user_id, provider, user_info, client_ip
81
+ )
82
+ action = "updated"
83
+ else:
84
+ # Create new social account link
85
+ await SocialAccountModel.create_social_account(
86
+ user_id, provider, user_info, client_ip
87
+ )
88
+ action = "linked"
89
+
90
+ # Log the action
91
+ await self._log_account_action(
92
+ user_id, f"social_account_{action}",
93
+ {"provider": provider, "client_ip": client_ip}
94
+ )
95
+
96
+ return {"action": action, "provider": provider, "user_info": user_info}
97
+
98
+ except Exception as e:
99
+ logger.error(f"Error linking social account for user {user_id}: {str(e)}")
100
+ raise
101
+
102
+ async def unlink_social_account(self, user_id: str, provider: str) -> Dict[str, Any]:
103
+ """Unlink a social account from a user"""
104
+ try:
105
+ # Check if account exists
106
+ account = await SocialAccountModel.find_by_user_and_provider(user_id, provider)
107
+ if not account:
108
+ raise ValueError(f"No {provider} account found for this user")
109
+
110
+ # Check if this is the only authentication method
111
+ user = await BookMyServiceUserModel.find_by_id(user_id)
112
+ if not user:
113
+ raise ValueError("User not found")
114
+
115
+ # Count total social accounts
116
+ social_accounts = await SocialAccountModel.find_by_user_id(user_id)
117
+
118
+ # If user has no phone/email and this is their only social account, prevent unlinking
119
+ if (len(social_accounts) == 1 and
120
+ not user.get("phone") and not user.get("email")):
121
+ raise ValueError("Cannot unlink the only authentication method")
122
+
123
+ # Unlink the account
124
+ result = await SocialAccountModel.unlink_social_account(user_id, provider)
125
+
126
+ # Log the action
127
+ await self._log_account_action(
128
+ user_id, "social_account_unlinked",
129
+ {"provider": provider}
130
+ )
131
+
132
+ return {"action": "unlinked", "provider": provider, "result": result}
133
+
134
+ except Exception as e:
135
+ logger.error(f"Error unlinking social account for user {user_id}: {str(e)}")
136
+ raise
137
+
138
+ async def get_login_history(self, user_id: str, page: int = 1,
139
+ per_page: int = 10, days: int = 30) -> LoginHistoryResponse:
140
+ """Get login history for a user"""
141
+ try:
142
+ # Calculate date range
143
+ end_date = datetime.utcnow()
144
+ start_date = end_date - timedelta(days=days)
145
+
146
+ # Query security logs for login events
147
+ skip = (page - 1) * per_page
148
+
149
+ pipeline = [
150
+ {
151
+ "$match": {
152
+ "user_id": user_id,
153
+ "timestamp": {"$gte": start_date, "$lte": end_date},
154
+ "$or": [
155
+ {"path": {"$regex": "/login"}},
156
+ {"path": {"$regex": "/oauth"}},
157
+ {"path": {"$regex": "/otp"}}
158
+ ]
159
+ }
160
+ },
161
+ {"$sort": {"timestamp": -1}},
162
+ {"$skip": skip},
163
+ {"$limit": per_page}
164
+ ]
165
+
166
+ cursor = self.security_collection.aggregate(pipeline)
167
+ logs = await cursor.to_list(length=per_page)
168
+
169
+ # Count total entries
170
+ total_count = await self.security_collection.count_documents({
171
+ "user_id": user_id,
172
+ "timestamp": {"$gte": start_date, "$lte": end_date},
173
+ "$or": [
174
+ {"path": {"$regex": "/login"}},
175
+ {"path": {"$regex": "/oauth"}},
176
+ {"path": {"$regex": "/otp"}}
177
+ ]
178
+ })
179
+
180
+ # Convert to response format
181
+ entries = []
182
+ for log in logs:
183
+ method = "oauth" if "oauth" in log["path"] else "otp"
184
+ provider = None
185
+
186
+ # Extract provider from query params if available
187
+ if method == "oauth" and log.get("query_params"):
188
+ provider = log["query_params"].get("provider")
189
+
190
+ entry = LoginHistoryEntry(
191
+ timestamp=log["timestamp"],
192
+ method=method,
193
+ provider=provider,
194
+ ip_address=log.get("client_ip"),
195
+ success=log["status_code"] < 400,
196
+ device_info=log.get("device_info", {}).get("user_agent")
197
+ )
198
+ entries.append(entry)
199
+
200
+ return LoginHistoryResponse(
201
+ entries=entries,
202
+ total_entries=total_count,
203
+ page=page,
204
+ per_page=per_page
205
+ )
206
+
207
+ except Exception as e:
208
+ logger.error(f"Error getting login history for user {user_id}: {str(e)}")
209
+ raise
210
+
211
+ async def get_security_settings(self, user_id: str) -> SecuritySettingsResponse:
212
+ """Get security settings and status for a user"""
213
+ try:
214
+ # Get user info
215
+ user = await BookMyServiceUserModel.find_by_id(user_id)
216
+ if not user:
217
+ raise ValueError("User not found")
218
+
219
+ # Count linked social accounts
220
+ social_accounts = await SocialAccountModel.find_by_user_id(user_id)
221
+ linked_accounts_count = len(social_accounts)
222
+
223
+ # Get recent login attempts (last 24 hours)
224
+ yesterday = datetime.utcnow() - timedelta(days=1)
225
+ recent_attempts = await self.security_collection.count_documents({
226
+ "user_id": user_id,
227
+ "timestamp": {"$gte": yesterday},
228
+ "$or": [
229
+ {"path": {"$regex": "/login"}},
230
+ {"path": {"$regex": "/oauth"}},
231
+ {"path": {"$regex": "/otp"}}
232
+ ]
233
+ })
234
+
235
+ # Check if account is locked (this would be implemented based on your locking logic)
236
+ account_locked = False # Implement based on your account locking mechanism
237
+
238
+ return SecuritySettingsResponse(
239
+ two_factor_enabled=False, # Implement 2FA if needed
240
+ linked_social_accounts=linked_accounts_count,
241
+ last_password_change=None, # Implement if you have password functionality
242
+ recent_login_attempts=recent_attempts,
243
+ account_locked=account_locked
244
+ )
245
+
246
+ except Exception as e:
247
+ logger.error(f"Error getting security settings for user {user_id}: {str(e)}")
248
+ raise
249
+
250
+ async def merge_social_accounts(self, primary_user_id: str, secondary_user_id: str,
251
+ client_ip: str) -> Dict[str, Any]:
252
+ """Merge social accounts from secondary user to primary user"""
253
+ try:
254
+ # Get social accounts from secondary user
255
+ secondary_accounts = await SocialAccountModel.find_by_user_id(secondary_user_id)
256
+
257
+ merged_count = 0
258
+ for account in secondary_accounts:
259
+ # Check if primary user already has this provider
260
+ existing = await SocialAccountModel.find_by_user_and_provider(
261
+ primary_user_id, account["provider"]
262
+ )
263
+
264
+ if not existing:
265
+ # Transfer the account to primary user
266
+ await SocialAccountModel.update_user_id(
267
+ account["_id"], primary_user_id
268
+ )
269
+ merged_count += 1
270
+
271
+ # Log the merge action
272
+ await self._log_account_action(
273
+ primary_user_id, "accounts_merged",
274
+ {
275
+ "secondary_user_id": secondary_user_id,
276
+ "merged_accounts": merged_count,
277
+ "client_ip": client_ip
278
+ }
279
+ )
280
+
281
+ return {
282
+ "merged_accounts": merged_count,
283
+ "primary_user_id": primary_user_id,
284
+ "secondary_user_id": secondary_user_id
285
+ }
286
+
287
+ except Exception as e:
288
+ logger.error(f"Error merging accounts {secondary_user_id} -> {primary_user_id}: {str(e)}")
289
+ raise
290
+
291
+ async def revoke_all_sessions(self, user_id: str, client_ip: str) -> Dict[str, Any]:
292
+ """Revoke all active sessions for a user"""
293
+ try:
294
+ # In a real implementation, you'd have a sessions collection
295
+ # For now, we'll just log the action
296
+ await self._log_account_action(
297
+ user_id, "all_sessions_revoked",
298
+ {"client_ip": client_ip}
299
+ )
300
+
301
+ # Here you would typically:
302
+ # 1. Delete all session tokens from database
303
+ # 2. Add tokens to a blacklist
304
+ # 3. Force re-authentication on next request
305
+
306
+ return {"action": "revoked", "user_id": user_id}
307
+
308
+ except Exception as e:
309
+ logger.error(f"Error revoking sessions for user {user_id}: {str(e)}")
310
+ raise
311
+
312
+ async def get_trusted_devices(self, user_id: str) -> List[Dict[str, Any]]:
313
+ """Get list of trusted devices for a user"""
314
+ try:
315
+ cursor = self.device_collection.find({
316
+ "user_id": user_id,
317
+ "is_trusted": True
318
+ }).sort("last_seen", -1)
319
+
320
+ devices = await cursor.to_list(length=None)
321
+
322
+ # Format device information
323
+ trusted_devices = []
324
+ for device in devices:
325
+ device_info = {
326
+ "device_id": str(device["_id"]),
327
+ "device_fingerprint": device["device_fingerprint"],
328
+ "platform": device.get("device_info", {}).get("platform", "Unknown"),
329
+ "browser": device.get("device_info", {}).get("browser", "Unknown"),
330
+ "first_seen": device["first_seen"],
331
+ "last_seen": device["last_seen"],
332
+ "access_count": device.get("access_count", 0)
333
+ }
334
+ trusted_devices.append(device_info)
335
+
336
+ return trusted_devices
337
+
338
+ except Exception as e:
339
+ logger.error(f"Error getting trusted devices for user {user_id}: {str(e)}")
340
+ raise
341
+
342
+ async def remove_trusted_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
343
+ """Remove a trusted device"""
344
+ try:
345
+ result = await self.device_collection.update_one(
346
+ {
347
+ "_id": ObjectId(device_id),
348
+ "user_id": user_id
349
+ },
350
+ {"$set": {"is_trusted": False}}
351
+ )
352
+
353
+ if result.matched_count == 0:
354
+ raise ValueError("Device not found or not owned by user")
355
+
356
+ await self._log_account_action(
357
+ user_id, "trusted_device_removed",
358
+ {"device_id": device_id}
359
+ )
360
+
361
+ return {"action": "removed", "device_id": device_id}
362
+
363
+ except Exception as e:
364
+ logger.error(f"Error removing trusted device for user {user_id}: {str(e)}")
365
+ raise
366
+
367
+ async def _verify_social_token(self, provider: str, token: str) -> Dict[str, Any]:
368
+ """Verify social media token and return user info"""
369
+ try:
370
+ if provider == "google":
371
+ return await verify_google_token(token)
372
+ elif provider == "apple":
373
+ return await verify_apple_token(token)
374
+ elif provider == "facebook":
375
+ return await verify_facebook_token(token)
376
+ else:
377
+ raise ValueError(f"Unsupported provider: {provider}")
378
+ except Exception as e:
379
+ logger.error(f"Token verification failed for {provider}: {str(e)}")
380
+ raise ValueError(f"Invalid {provider} token")
381
+
382
+ async def _log_account_action(self, user_id: str, action: str, details: Dict[str, Any]):
383
+ """Log account-related actions for audit purposes"""
384
+ try:
385
+ log_entry = {
386
+ "timestamp": datetime.utcnow(),
387
+ "user_id": user_id,
388
+ "action": action,
389
+ "details": details,
390
+ "type": "account_management"
391
+ }
392
+
393
+ await self.security_collection.insert_one(log_entry)
394
+
395
+ except Exception as e:
396
+ logger.error(f"Failed to log account action: {str(e)}")
app/services/user_service.py CHANGED
@@ -4,6 +4,7 @@ from datetime import datetime, timedelta
4
  from fastapi import HTTPException
5
  from app.models.user_model import BookMyServiceUserModel
6
  from app.models.otp_model import BookMyServiceOTPModel
 
7
  from app.core.config import settings
8
  from app.utils.common_utils import is_email, validate_identifier
9
  from app.schemas.user_schema import UserRegisterRequest
@@ -11,16 +12,27 @@ import logging
11
 
12
  logger = logging.getLogger("user_service")
13
 
 
 
 
14
  class UserService:
15
  @staticmethod
16
- async def send_otp(identifier: str, phone: str = None):
17
- logger.info(f"UserService.send_otp called - identifier: {identifier}, phone: {phone}")
18
 
19
  try:
20
  # Validate identifier format
21
  identifier_type = validate_identifier(identifier)
22
  logger.debug(f"Identifier type: {identifier_type}")
23
 
 
 
 
 
 
 
 
 
24
  # For phone identifiers, use the identifier itself as phone
25
  # For email identifiers, use the provided phone parameter
26
  if identifier_type == "phone":
@@ -31,13 +43,19 @@ class UserService:
31
  # If email identifier but no phone provided, we'll send OTP via email
32
  phone_number = None
33
 
34
- # Using dummy OTP for testing
35
- otp = "777777"
36
- logger.debug(f"Generated OTP: {otp} for identifier: {identifier}")
 
37
 
38
  await BookMyServiceOTPModel.store_otp(identifier, phone_number, otp)
 
 
 
 
 
39
  logger.info(f"OTP stored successfully for identifier: {identifier}")
40
- logger.debug(f"OTP sent to {identifier}: {otp}")
41
 
42
  except ValueError as ve:
43
  logger.error(f"Validation error for identifier {identifier}: {str(ve)}")
@@ -47,23 +65,32 @@ class UserService:
47
  raise HTTPException(status_code=500, detail="Failed to send OTP")
48
 
49
  @staticmethod
50
- async def otp_login_handler(identifier: str, otp: str):
51
- logger.info(f"UserService.otp_login_handler called - identifier: {identifier}, otp: {otp}")
52
 
53
  try:
54
  # Validate identifier format
55
  identifier_type = validate_identifier(identifier)
56
  logger.debug(f"Identifier type: {identifier_type}")
57
 
58
- # Verify OTP
 
 
 
 
 
59
  logger.debug(f"Verifying OTP for identifier: {identifier}")
60
- otp_valid = await BookMyServiceOTPModel.verify_otp(identifier, otp)
61
  logger.debug(f"OTP verification result: {otp_valid}")
62
 
63
  if not otp_valid:
64
  logger.warning(f"Invalid or expired OTP for identifier: {identifier}")
 
 
65
  raise HTTPException(status_code=400, detail="Invalid or expired OTP")
66
 
 
 
67
  logger.info(f"OTP verification successful for identifier: {identifier}")
68
 
69
  # Find user by identifier
@@ -108,17 +135,28 @@ class UserService:
108
  async def register(data: UserRegisterRequest, decoded):
109
  logger.info(f"Registering user with data: {data}")
110
 
111
- if data.mode == "otp":
112
- identifier = data.otpIdentifer or decoded.get("sub")
113
- if not identifier:
114
- raise HTTPException(status_code=400, detail="Missing verified identifier")
 
 
 
 
 
115
 
116
- # Validate identifier format
 
 
 
 
117
  try:
118
  identifier_type = validate_identifier(identifier)
 
 
119
  logger.debug(f"Registration identifier type: {identifier_type}")
120
  except ValueError as ve:
121
- logger.error(f"Invalid identifier format during registration: {str(ve)}")
122
  raise HTTPException(status_code=400, detail=str(ve))
123
 
124
  redis_key = f"bms_otp:{identifier}"
@@ -133,9 +171,44 @@ class UserService:
133
  user_id = f"otp_{identifier}"
134
 
135
  elif data.mode == "oauth":
 
136
  if not data.oauth_token or not data.provider:
137
- raise HTTPException(status_code=400, detail="OAuth token and provider required")
138
- user_id = f"{data.provider}_{data.oauth_token}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  else:
140
  raise HTTPException(status_code=400, detail="Unsupported registration mode")
141
 
@@ -151,6 +224,7 @@ class UserService:
151
  if existing_user:
152
  raise HTTPException(status_code=409, detail="User with this email or phone already exists")
153
 
 
154
  user_doc = {
155
  "user_id": user_id,
156
  "name": data.name,
@@ -159,7 +233,20 @@ class UserService:
159
  "auth_mode": data.mode,
160
  "created_at": datetime.utcnow()
161
  }
 
 
 
 
 
162
  await BookMyServiceUserModel.collection.insert_one(user_doc)
 
 
 
 
 
 
 
 
163
 
164
  token_data = {
165
  "sub": user_id,
 
4
  from fastapi import HTTPException
5
  from app.models.user_model import BookMyServiceUserModel
6
  from app.models.otp_model import BookMyServiceOTPModel
7
+ from app.models.social_account_model import SocialAccountModel
8
  from app.core.config import settings
9
  from app.utils.common_utils import is_email, validate_identifier
10
  from app.schemas.user_schema import UserRegisterRequest
 
12
 
13
  logger = logging.getLogger("user_service")
14
 
15
+
16
+
17
+
18
  class UserService:
19
  @staticmethod
20
+ async def send_otp(identifier: str, phone: str = None, client_ip: str = None):
21
+ logger.info(f"UserService.send_otp called - identifier: {identifier}, phone: {phone}, ip: {client_ip}")
22
 
23
  try:
24
  # Validate identifier format
25
  identifier_type = validate_identifier(identifier)
26
  logger.debug(f"Identifier type: {identifier_type}")
27
 
28
+ # Enhanced rate limiting by IP and identifier
29
+ if client_ip:
30
+ ip_rate_key = f"otp_ip_rate:{client_ip}"
31
+ ip_attempts = await BookMyServiceOTPModel.get_rate_limit_count(ip_rate_key)
32
+ if ip_attempts >= 10: # Max 10 OTPs per IP per hour
33
+ logger.warning(f"IP rate limit exceeded for {client_ip}")
34
+ raise HTTPException(status_code=429, detail="Too many OTP requests from this IP")
35
+
36
  # For phone identifiers, use the identifier itself as phone
37
  # For email identifiers, use the provided phone parameter
38
  if identifier_type == "phone":
 
43
  # If email identifier but no phone provided, we'll send OTP via email
44
  phone_number = None
45
 
46
+ # Generate secure OTP (6 digits, cryptographically secure)
47
+ import secrets
48
+ otp = ''.join([str(secrets.randbelow(10)) for _ in range(6)])
49
+ logger.debug(f"Generated secure OTP for identifier: {identifier}")
50
 
51
  await BookMyServiceOTPModel.store_otp(identifier, phone_number, otp)
52
+
53
+ # Track IP-based rate limiting
54
+ if client_ip:
55
+ await BookMyServiceOTPModel.increment_rate_limit(ip_rate_key, 3600) # 1 hour window
56
+
57
  logger.info(f"OTP stored successfully for identifier: {identifier}")
58
+ logger.debug(f"OTP sent to {identifier}")
59
 
60
  except ValueError as ve:
61
  logger.error(f"Validation error for identifier {identifier}: {str(ve)}")
 
65
  raise HTTPException(status_code=500, detail="Failed to send OTP")
66
 
67
  @staticmethod
68
+ async def otp_login_handler(identifier: str, otp: str, client_ip: str = None):
69
+ logger.info(f"UserService.otp_login_handler called - identifier: {identifier}, otp: {otp}, ip: {client_ip}")
70
 
71
  try:
72
  # Validate identifier format
73
  identifier_type = validate_identifier(identifier)
74
  logger.debug(f"Identifier type: {identifier_type}")
75
 
76
+ # Check if account is locked
77
+ if await BookMyServiceOTPModel.is_account_locked(identifier):
78
+ logger.warning(f"Account locked for identifier: {identifier}")
79
+ raise HTTPException(status_code=423, detail="Account temporarily locked due to too many failed attempts")
80
+
81
+ # Verify OTP with client IP tracking
82
  logger.debug(f"Verifying OTP for identifier: {identifier}")
83
+ otp_valid = await BookMyServiceOTPModel.verify_otp(identifier, otp, client_ip)
84
  logger.debug(f"OTP verification result: {otp_valid}")
85
 
86
  if not otp_valid:
87
  logger.warning(f"Invalid or expired OTP for identifier: {identifier}")
88
+ # Track failed attempt
89
+ await BookMyServiceOTPModel.track_failed_attempt(identifier, client_ip)
90
  raise HTTPException(status_code=400, detail="Invalid or expired OTP")
91
 
92
+ # Clear failed attempts on successful verification
93
+ await BookMyServiceOTPModel.clear_failed_attempts(identifier)
94
  logger.info(f"OTP verification successful for identifier: {identifier}")
95
 
96
  # Find user by identifier
 
135
  async def register(data: UserRegisterRequest, decoded):
136
  logger.info(f"Registering user with data: {data}")
137
 
138
+ # Validate mandatory fields for all registration modes
139
+ if not data.name or not data.name.strip():
140
+ raise HTTPException(status_code=400, detail="Name is required")
141
+
142
+ if not data.email:
143
+ raise HTTPException(status_code=400, detail="Email is required")
144
+
145
+ if not data.phone or not data.phone.strip():
146
+ raise HTTPException(status_code=400, detail="Phone is required")
147
 
148
+ if data.mode == "otp":
149
+ # Always use phone as the OTP identifier as per documentation
150
+ identifier = data.phone
151
+
152
+ # Validate phone format
153
  try:
154
  identifier_type = validate_identifier(identifier)
155
+ if identifier_type != "phone":
156
+ raise ValueError("Phone number format is invalid")
157
  logger.debug(f"Registration identifier type: {identifier_type}")
158
  except ValueError as ve:
159
+ logger.error(f"Invalid phone format during registration: {str(ve)}")
160
  raise HTTPException(status_code=400, detail=str(ve))
161
 
162
  redis_key = f"bms_otp:{identifier}"
 
171
  user_id = f"otp_{identifier}"
172
 
173
  elif data.mode == "oauth":
174
+ # Validate OAuth-specific mandatory fields
175
  if not data.oauth_token or not data.provider:
176
+ raise HTTPException(status_code=400, detail="OAuth token and provider are required")
177
+
178
+ # Extract user info from decoded token
179
+ user_info = decoded.get("user_info", {})
180
+ provider_user_id = user_info.get("sub") or user_info.get("id")
181
+
182
+ if not provider_user_id:
183
+ raise HTTPException(status_code=400, detail="Invalid OAuth user information")
184
+
185
+ # Check if this social account already exists
186
+ existing_social_account = await SocialAccountModel.find_by_provider_and_user_id(
187
+ data.provider, provider_user_id
188
+ )
189
+
190
+ if existing_social_account:
191
+ # User already has this social account linked
192
+ existing_user = await BookMyServiceUserModel.collection.find_one({
193
+ "user_id": existing_social_account["user_id"]
194
+ })
195
+ if existing_user:
196
+ # Update social account with latest info and return existing user token
197
+ await SocialAccountModel.update_social_account(data.provider, provider_user_id, user_info)
198
+
199
+ token_data = {
200
+ "sub": existing_user["user_id"],
201
+ "user_id": existing_user["user_id"],
202
+ "email": existing_user.get("email"),
203
+ "phone": existing_user.get("phone"),
204
+ "role": "user",
205
+ "exp": datetime.utcnow() + timedelta(hours=8)
206
+ }
207
+ access_token = jwt.encode(token_data, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
208
+ return {"access_token": access_token}
209
+
210
+ user_id = f"{data.provider}_{provider_user_id}"
211
+
212
  else:
213
  raise HTTPException(status_code=400, detail="Unsupported registration mode")
214
 
 
224
  if existing_user:
225
  raise HTTPException(status_code=409, detail="User with this email or phone already exists")
226
 
227
+ # Create user document
228
  user_doc = {
229
  "user_id": user_id,
230
  "name": data.name,
 
233
  "auth_mode": data.mode,
234
  "created_at": datetime.utcnow()
235
  }
236
+
237
+ # Add profile picture from social account if available
238
+ if data.mode == "oauth" and user_info.get("picture"):
239
+ user_doc["profile_picture"] = user_info["picture"]
240
+
241
  await BookMyServiceUserModel.collection.insert_one(user_doc)
242
+ logger.info(f"Created new user: {user_id}")
243
+
244
+ # Create social account record for OAuth registration
245
+ if data.mode == "oauth":
246
+ await SocialAccountModel.create_social_account(
247
+ user_id, data.provider, provider_user_id, user_info
248
+ )
249
+ logger.info(f"Created social account link for {data.provider}")
250
 
251
  token_data = {
252
  "sub": user_id,
app/utils/social_utils.py CHANGED
@@ -15,6 +15,62 @@ class TokenVerificationError(Exception):
15
  """Custom exception for token verification errors"""
16
  pass
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  class GoogleTokenVerifier:
19
  def __init__(self, client_id: str):
20
  self.client_id = client_id
@@ -140,9 +196,11 @@ class AppleTokenVerifier:
140
 
141
  # Factory class for easier usage
142
  class OAuthVerifier:
143
- def __init__(self, google_client_id: Optional[str] = None, apple_audience: Optional[str] = None):
 
144
  self.google_verifier = GoogleTokenVerifier(google_client_id) if google_client_id else None
145
  self.apple_verifier = AppleTokenVerifier(apple_audience) if apple_audience else None
 
146
 
147
  async def verify_google_token(self, token: str) -> Dict:
148
  if not self.google_verifier:
@@ -153,6 +211,11 @@ class OAuthVerifier:
153
  if not self.apple_verifier:
154
  raise TokenVerificationError("Apple verifier not configured")
155
  return await self.apple_verifier.verify_token(token)
 
 
 
 
 
156
 
157
  # Convenience functions (backward compatibility)
158
  async def verify_google_token(token: str, client_id: str) -> Dict:
@@ -169,6 +232,13 @@ async def verify_apple_token(token: str, audience: str) -> Dict:
169
  verifier = AppleTokenVerifier(audience)
170
  return await verifier.verify_token(token)
171
 
 
 
 
 
 
 
 
172
  # Example usage
173
  async def example_usage():
174
  # Initialize verifier
 
15
  """Custom exception for token verification errors"""
16
  pass
17
 
18
+ class FacebookTokenVerifier:
19
+ def __init__(self, app_id: str, app_secret: str):
20
+ self.app_id = app_id
21
+ self.app_secret = app_secret
22
+
23
+ async def verify_token(self, token: str) -> Dict:
24
+ """
25
+ Asynchronously verifies a Facebook access token and returns user data.
26
+ """
27
+ try:
28
+ # First, verify the token with Facebook's debug endpoint
29
+ async with httpx.AsyncClient(timeout=10.0) as client:
30
+ # Verify token validity
31
+ debug_url = f"https://graph.facebook.com/debug_token"
32
+ debug_params = {
33
+ "input_token": token,
34
+ "access_token": f"{self.app_id}|{self.app_secret}"
35
+ }
36
+
37
+ debug_response = await client.get(debug_url, params=debug_params)
38
+ debug_response.raise_for_status()
39
+ debug_data = debug_response.json()
40
+
41
+ if not debug_data.get("data", {}).get("is_valid"):
42
+ raise TokenVerificationError("Invalid Facebook token")
43
+
44
+ # Check if token is for our app
45
+ token_app_id = debug_data.get("data", {}).get("app_id")
46
+ if token_app_id != self.app_id:
47
+ raise TokenVerificationError("Token not for this app")
48
+
49
+ # Get user data
50
+ user_url = "https://graph.facebook.com/me"
51
+ user_params = {
52
+ "access_token": token,
53
+ "fields": "id,name,email,picture.type(large)"
54
+ }
55
+
56
+ user_response = await client.get(user_url, params=user_params)
57
+ user_response.raise_for_status()
58
+ user_data = user_response.json()
59
+
60
+ # Validate required fields
61
+ if not user_data.get("id"):
62
+ raise TokenVerificationError("Missing user ID in Facebook response")
63
+
64
+ logger.info(f"Successfully verified Facebook token for user: {user_data.get('email', user_data.get('id'))}")
65
+ return user_data
66
+
67
+ except httpx.RequestError as e:
68
+ logger.error(f"Facebook token verification request failed: {str(e)}")
69
+ raise TokenVerificationError(f"Facebook API request failed: {str(e)}")
70
+ except Exception as e:
71
+ logger.error(f"Facebook token verification failed: {str(e)}")
72
+ raise TokenVerificationError(f"Invalid Facebook token: {str(e)}")
73
+
74
  class GoogleTokenVerifier:
75
  def __init__(self, client_id: str):
76
  self.client_id = client_id
 
196
 
197
  # Factory class for easier usage
198
  class OAuthVerifier:
199
+ def __init__(self, google_client_id: Optional[str] = None, apple_audience: Optional[str] = None,
200
+ facebook_app_id: Optional[str] = None, facebook_app_secret: Optional[str] = None):
201
  self.google_verifier = GoogleTokenVerifier(google_client_id) if google_client_id else None
202
  self.apple_verifier = AppleTokenVerifier(apple_audience) if apple_audience else None
203
+ self.facebook_verifier = FacebookTokenVerifier(facebook_app_id, facebook_app_secret) if facebook_app_id and facebook_app_secret else None
204
 
205
  async def verify_google_token(self, token: str) -> Dict:
206
  if not self.google_verifier:
 
211
  if not self.apple_verifier:
212
  raise TokenVerificationError("Apple verifier not configured")
213
  return await self.apple_verifier.verify_token(token)
214
+
215
+ async def verify_facebook_token(self, token: str) -> Dict:
216
+ if not self.facebook_verifier:
217
+ raise TokenVerificationError("Facebook verifier not configured")
218
+ return await self.facebook_verifier.verify_token(token)
219
 
220
  # Convenience functions (backward compatibility)
221
  async def verify_google_token(token: str, client_id: str) -> Dict:
 
232
  verifier = AppleTokenVerifier(audience)
233
  return await verifier.verify_token(token)
234
 
235
+ async def verify_facebook_token(token: str, app_id: str, app_secret: str) -> Dict:
236
+ """
237
+ Asynchronously verifies a Facebook access token and returns user data.
238
+ """
239
+ verifier = FacebookTokenVerifier(app_id, app_secret)
240
+ return await verifier.verify_token(token)
241
+
242
  # Example usage
243
  async def example_usage():
244
  # Initialize verifier