diff --git a/app/app.py b/app/app.py index 6809945a460602c26c343578b03cd0bc767b9d0e..10f73ec549ea23b63d9212e8daa8489822aeb2be 100644 --- a/app/app.py +++ b/app/app.py @@ -4,13 +4,26 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse from insightfy_utils.logging import setup_logging, get_logger from insightfy_utils.config import load_env -from app.routers.user import router as user_router from app.middleware.security_middleware import create_security_middleware +from app.middleware.rate_limiter import RateLimitMiddleware from app.config.validation import validate_environment_variables, validate_database_isolation import os import time import asyncio +# Import migrated routers +from app.routers import ( + user_router, + profile_router, + account_router, + wallet_router, + address_router, + pet_router, + guest_router, + favorite_router, + review_router +) + load_env() _env = os.getenv("ENVIRONMENT", "development") _log_level = os.getenv("LOG_LEVEL") or ("WARNING" if _env == "development" else "INFO") @@ -51,6 +64,9 @@ app.add_middleware( enable_security_headers=True ) +# Add rate limiting middleware +app.add_middleware(RateLimitMiddleware, calls=100, period=60) + allowed_origins = [ "http://localhost:3000", "http://localhost:3001", @@ -94,91 +110,36 @@ async def add_service_headers(request: Request, call_next): response.headers["X-Powered-By"] = "Insightfy" return response -app.include_router( - user_router, - prefix="/ums/v1/users", - tags=["Users"], - responses={404: {"description": "Not found"}, 500: {"description": "Internal error"}}, -) +# Include Routers +# Mapping source paths to /ums/v1/... convention -@app.get("/", tags=["Health"]) -async def root(): - return { - "service": "Insightfy BMS - User Management Service", - "version": "1.0.0", - "status": "running", - "description": "User Management Service API for Insightfy BMS platform", - "docs_url": "/docs", - "health_check": "/ums/health", - } - -@app.get("/ums/health", tags=["Health"]) -async def health_check(): - return {"status": "healthy", "service": "insightfy-bms-ms-ums", "version": "1.0.0", "timestamp": time.time()} - -@app.get("/ums/health/detailed", tags=["Health"]) -async def detailed_health_check(): - try: - health_status = await get_comprehensive_health_status() - return health_status - except Exception as e: - logger.error("detailed_health_failed", extra={"error": str(e), "service": "insightfy-bms-ms-ums"}, exc_info=True) - return {"status": "unhealthy", "service": "insightfy-bms-ms-ums", "version": "1.0.0", "error": "Health check failed", "timestamp": time.time()} +# Auth / User +app.include_router(user_router.router, prefix="/ums/v1/auth", tags=["User Auth"]) -@app.get("/ums/health/metrics", tags=["Health"]) -async def health_metrics(): - try: - from app.nosql import get_connection_metrics - connection_metrics = await get_connection_metrics() - return {"service": "insightfy-bms-ms-ums", "version": "1.0.0", "connection_metrics": connection_metrics, "status": "healthy", "timestamp": time.time(), "environment": os.getenv("ENVIRONMENT", "development")} - except Exception as e: - logger.error("metrics_failed", extra={"error": str(e), "service": "insightfy-bms-ms-ums"}, exc_info=True) - return {"service": "insightfy-bms-ms-ums", "version": "1.0.0", "error": "Failed to retrieve metrics", "status": "unhealthy", "timestamp": time.time()} +# Profile +app.include_router(profile_router.router, prefix="/ums/v1/profile", tags=["Profile"]) -@app.get("/ums/health/readiness", tags=["Health"]) -async def readiness_check(): - try: - from app.nosql import ping_mongo, ping_redis - mongo_ready = await ping_mongo() - redis_ready = await ping_redis() - if mongo_ready and redis_ready: - return {"status": "ready", "service": "insightfy-bms-ms-ums", "version": "1.0.0", "timestamp": time.time()} - else: - return JSONResponse(status_code=503, content={"status": "not_ready", "service": "insightfy-bms-ms-ums", "version": "1.0.0", "components": {"mongodb": "ready" if mongo_ready else "not_ready", "redis": "ready" if redis_ready else "not_ready"}, "timestamp": time.time()}) - except Exception as e: - logger.error("readiness_failed", extra={"error": str(e), "service": "insightfy-bms-ms-ums"}, exc_info=True) - return JSONResponse(status_code=503, content={"status": "not_ready", "service": "insightfy-bms-ms-ums", "error": "Readiness check failed", "timestamp": time.time()}) - -@app.get("/ums/health/liveness", tags=["Health"]) -async def liveness_check(): - current_time = time.time() - return {"status": "alive", "service": "insightfy-bms-ms-ums", "version": "1.0.0", "timestamp": current_time, "uptime_seconds": round(current_time - startup_time, 2)} - -async def get_comprehensive_health_status(): - from app.nosql import get_database_status - from app.config.validation import validate_environment_variables, validate_database_isolation - db_status = await get_database_status() - try: - config_status = validate_environment_variables() - config_valid = True - except Exception as e: - config_status = {"error": str(e)} - config_valid = False - isolation_valid = validate_database_isolation() - overall_healthy = (db_status["overall_status"] == "healthy" and config_valid and isolation_valid) - return { - "status": "healthy" if overall_healthy else "unhealthy", - "service": "insightfy-bms-ms-ums", - "version": "1.0.0", - "components": { - "mongodb": "healthy" if db_status["mongodb"]["connected"] else "unhealthy", - "redis": "healthy" if db_status["redis"]["connected"] else "unhealthy", - "configuration": "healthy" if config_valid else "unhealthy", - "database_isolation": "healthy" if isolation_valid else "unhealthy", - }, - "database": {"name": db_status["mongodb"]["database"], "connected": db_status["mongodb"]["connected"], "isolation_valid": isolation_valid}, - "configuration": config_status, - "metrics": {"check_duration_ms": db_status["check_duration_ms"], "timestamp": db_status["timestamp"]}, - "environment": os.getenv("ENVIRONMENT", "development"), - } +# Account +app.include_router(account_router.router, prefix="/ums/v1/account", tags=["Account Management"]) + +# Wallet +app.include_router(wallet_router.router, prefix="/ums/v1/wallet", tags=["Wallet Management"]) + +# Address +app.include_router(address_router.router, prefix="/ums/v1/addresses", tags=["Address Management"]) + +# Others (Pet, Guest, Favorite were under /api/v1/users in source) +app.include_router(pet_router.router, prefix="/ums/v1/users", tags=["Pet Management"]) +app.include_router(guest_router.router, prefix="/ums/v1/users", tags=["Guest Management"]) +app.include_router(favorite_router.router, prefix="/ums/v1/users", tags=["Favorites"]) + +# Reviews +app.include_router(review_router.router, prefix="/ums/v1/reviews", tags=["Reviews"]) + +@app.get("/") +def root(): + return {"message": "Insightfy BMS User Management Service is running"} +@app.get("/health") +def health(): + return {"status": "healthy", "service": "insightfy-bms-ms-ums", "timestamp": time.time()} diff --git a/app/core/cache_client.py b/app/core/cache_client.py new file mode 100644 index 0000000000000000000000000000000000000000..a08ed588da633b43b637bd87f1c5416b4f0f55bf --- /dev/null +++ b/app/core/cache_client.py @@ -0,0 +1,79 @@ +import logging +from redis.asyncio import Redis +from redis.exceptions import RedisError, ConnectionError, AuthenticationError +from app.core.config import settings + +logger = logging.getLogger(__name__) + +# Parse host and port +CACHE_HOST, CACHE_PORT = settings.CACHE_URI.split(":") +CACHE_PORT = int(CACHE_PORT) + +async def create_redis_client(): + """Create Redis client with proper error handling and fallback""" + try: + # First try with authentication if password is provided + if settings.CACHE_K and settings.CACHE_K.strip(): + redis_client = Redis( + host=CACHE_HOST, + port=CACHE_PORT, + username="default", + password=settings.CACHE_K, + decode_responses=True, + socket_connect_timeout=5, + socket_timeout=5, + retry_on_timeout=True + ) + # Test the connection + await redis_client.ping() + logger.info(f"Connected to Redis at {CACHE_HOST}:{CACHE_PORT} with authentication") + return redis_client + else: + # Try without authentication for local Redis + redis_client = Redis( + host=CACHE_HOST, + port=CACHE_PORT, + decode_responses=True, + socket_connect_timeout=5, + socket_timeout=5, + retry_on_timeout=True + ) + # Test the connection + await redis_client.ping() + logger.info(f"Connected to Redis at {CACHE_HOST}:{CACHE_PORT} without authentication") + return redis_client + + except AuthenticationError as e: + logger.warning(f"Authentication failed for Redis: {e}") + # Try without authentication as fallback + try: + redis_client = Redis( + host=CACHE_HOST, + port=CACHE_PORT, + decode_responses=True, + socket_connect_timeout=5, + socket_timeout=5, + retry_on_timeout=True + ) + await redis_client.ping() + logger.info(f"Connected to Redis at {CACHE_HOST}:{CACHE_PORT} without authentication (fallback)") + return redis_client + except Exception as fallback_error: + logger.error(f"Redis fallback connection also failed: {fallback_error}") + raise + + except ConnectionError as e: + logger.error(f"Failed to connect to Redis at {CACHE_HOST}:{CACHE_PORT}: {e}") + raise + except Exception as e: + logger.error(f"Unexpected error connecting to Redis: {e}") + raise + +# Initialize Redis client +redis_client = None + +async def get_redis() -> Redis: + global redis_client + if redis_client is None: + redis_client = await create_redis_client() + return redis_client \ No newline at end of file diff --git a/app/core/config.py b/app/core/config.py new file mode 100644 index 0000000000000000000000000000000000000000..336791cca498e8c2e3a26adac6957b68a988a237 --- /dev/null +++ b/app/core/config.py @@ -0,0 +1,70 @@ +from dotenv import load_dotenv +import os + +load_dotenv() + +class Settings: + # MongoDB + MONGO_URI: str = os.getenv("MONGO_URI") + DB_NAME: str = os.getenv("DB_NAME") + + # Redis + CACHE_URI: str = os.getenv("CACHE_URI") + CACHE_K: str = os.getenv("CACHE_K") + + # JWT (Unified across services) + # Prefer JWT_* envs; fall back to legacy names to ensure compatibility + JWT_SECRET_KEY: str = os.getenv("JWT_SECRET_KEY") or os.getenv("SECRET_KEY", "B00Kmyservice@7") + JWT_ALGORITHM: str = os.getenv("JWT_ALGORITHM") or os.getenv("ALGORITHM", "HS256") + JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = int( + os.getenv("JWT_ACCESS_TOKEN_EXPIRE_MINUTES", os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "43200")) + ) + JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = int( + os.getenv("JWT_REFRESH_TOKEN_EXPIRE_DAYS", os.getenv("REFRESH_TOKEN_EXPIRE_DAYS", "7")) + ) + JWT_TEMP_TOKEN_EXPIRE_MINUTES: int = int( + os.getenv("JWT_TEMP_TOKEN_EXPIRE_MINUTES", os.getenv("TEMP_TOKEN_EXPIRE_MINUTES", "10")) + ) + JWT_REMEMBER_ME_EXPIRE_DAYS: int = int( + os.getenv("JWT_REMEMBER_ME_EXPIRE_DAYS", "30") # 30 days for remember me + ) + + # Backward compatibility: keep legacy attributes pointing to unified values + SECRET_KEY: str = JWT_SECRET_KEY + ALGORITHM: str = JWT_ALGORITHM + + # Twilio SMS + TWILIO_ACCOUNT_SID: str = os.getenv("TWILIO_ACCOUNT_SID") + TWILIO_AUTH_TOKEN: str = os.getenv("TWILIO_AUTH_TOKEN") + TWILIO_SMS_FROM: str = os.getenv("TWILIO_SMS_FROM") + + # SMTP Email + SMTP_HOST: str = os.getenv("SMTP_HOST") + SMTP_PORT: int = int(os.getenv("SMTP_PORT", "587")) + SMTP_USER: str = os.getenv("SMTP_USER") + SMTP_PASS: str = os.getenv("SMTP_PASS") + SMTP_FROM: str = os.getenv("SMTP_FROM") + + # OAuth Providers + GOOGLE_CLIENT_ID: str = os.getenv("GOOGLE_CLIENT_ID") + APPLE_AUDIENCE: str = os.getenv("APPLE_AUDIENCE") + FACEBOOK_APP_ID: str = os.getenv("FACEBOOK_APP_ID") + FACEBOOK_APP_SECRET: str = os.getenv("FACEBOOK_APP_SECRET") + + # Local testing: bypass external OAuth verification when enabled + OAUTH_TEST_MODE: bool = os.getenv("OAUTH_TEST_MODE", "false").lower() == "true" + + # Security Settings + MAX_LOGIN_ATTEMPTS: int = int(os.getenv("MAX_LOGIN_ATTEMPTS", "5")) + ACCOUNT_LOCK_DURATION: int = int(os.getenv("ACCOUNT_LOCK_DURATION", "900")) # 15 minutes + OTP_VALIDITY_MINUTES: int = int(os.getenv("OTP_VALIDITY_MINUTES", "5")) + IP_RATE_LIMIT_MAX: int = int(os.getenv("IP_RATE_LIMIT_MAX", "10")) + IP_RATE_LIMIT_WINDOW: int = int(os.getenv("IP_RATE_LIMIT_WINDOW", "3600")) # 1 hour + + def __post_init__(self): + if not self.MONGO_URI or not self.DB_NAME: + raise ValueError("MongoDB URI or DB_NAME not configured.") + if not self.CACHE_URI or not self.CACHE_K: + raise ValueError("Redis URI or password (CACHE_K) not configured.") + +settings = Settings() \ No newline at end of file diff --git a/app/core/nosql_client.py b/app/core/nosql_client.py new file mode 100644 index 0000000000000000000000000000000000000000..7852cefe6f4abda4c9b35cdf01521a1a49cdbe73 --- /dev/null +++ b/app/core/nosql_client.py @@ -0,0 +1,10 @@ +import logging +from app.nosql import mongo_db, mongo_client + +logger = logging.getLogger(__name__) + +# Alias for backward compatibility with migrated code +db = mongo_db + +async def get_mongo_client(): + return mongo_client diff --git a/app/middleware/rate_limiter.py b/app/middleware/rate_limiter.py new file mode 100644 index 0000000000000000000000000000000000000000..032b83eef4d9c6f41431b5e988f0054f82cb8b5b --- /dev/null +++ b/app/middleware/rate_limiter.py @@ -0,0 +1,27 @@ +from fastapi import Request, HTTPException +from starlette.middleware.base import BaseHTTPMiddleware +import time +from collections import defaultdict, deque + +class RateLimitMiddleware(BaseHTTPMiddleware): + def __init__(self, app, calls: int = 100, period: int = 60): + super().__init__(app) + self.calls = calls + self.period = period + self.clients = defaultdict(deque) + + async def dispatch(self, request: Request, call_next): + client_ip = request.client.host + now = time.time() + + # Clean old requests + while self.clients[client_ip] and self.clients[client_ip][0] <= now - self.period: + self.clients[client_ip].popleft() + + # Check rate limit + if len(self.clients[client_ip]) >= self.calls: + raise HTTPException(status_code=429, detail="Rate limit exceeded") + + self.clients[client_ip].append(now) + response = await call_next(request) + return response \ No newline at end of file diff --git a/app/models/__init__.py b/app/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/app/models/address_model.py b/app/models/address_model.py new file mode 100644 index 0000000000000000000000000000000000000000..447be1ec28ac69d7f36949c23392896cc7cb0442 --- /dev/null +++ b/app/models/address_model.py @@ -0,0 +1,243 @@ +from datetime import datetime +from typing import Optional, List, Dict, Any +from bson import ObjectId +import logging +import uuid + +from app.core.nosql_client import db + +logger = logging.getLogger(__name__) + +class AddressModel: + """Model for managing user delivery addresses embedded under customer documents""" + + @staticmethod + async def create_address(customer_id: str, address_data: Dict[str, Any]) -> Optional[str]: + """Create a new embedded address for a user inside customers collection""" + try: + from app.models.user_model import BookMyServiceUserModel + + address_id = str(uuid.uuid4()) + current_time = datetime.utcnow() + + address_doc = { + "address_id": address_id, # New field for address identification + "address_line_1": address_data.get("address_line_1"), + "address_line_2": address_data.get("address_line_2", ""), + "city": address_data.get("city"), + "state": address_data.get("state"), + "postal_code": address_data.get("postal_code"), + "country": address_data.get("country", "India"), + "address_type": address_data.get("address_type", "home"), # home, work, other + "is_default": address_data.get("is_default", False), + "landmark": address_data.get("landmark", ""), + "created_at": current_time, + "updated_at": current_time, + } + + # Fetch user doc + user = await BookMyServiceUserModel.collection.find_one({"customer_id": customer_id}) + if not user: + logger.error(f"User not found for customer_id {customer_id}") + return None + + addresses = user.get("addresses", []) + + # If setting default, unset any existing defaults + if address_doc.get("is_default"): + for a in addresses: + if a.get("is_default"): + a["is_default"] = False + a["updated_at"] = datetime.utcnow() + else: + # If this is the first address, set default + if len(addresses) == 0: + address_doc["is_default"] = True + + addresses.append(address_doc) + + update_result = await BookMyServiceUserModel.collection.update_one( + {"customer_id": customer_id}, + {"$set": {"addresses": addresses}} + ) + + if update_result.modified_count == 0: + logger.error(f"Failed to insert embedded address for user {customer_id}") + return None + + logger.info(f"Created embedded address for user {customer_id}") + return address_doc["address_id"] # Return the address_id field instead of _id + + except Exception as e: + logger.error(f"Error creating embedded address for user {customer_id}: {str(e)}") + return None + + @staticmethod + async def get_user_addresses(customer_id: str) -> List[Dict[str, Any]]: + """Get all embedded addresses for a user from customers collection""" + try: + from app.models.user_model import BookMyServiceUserModel + user = await BookMyServiceUserModel.collection.find_one({"customer_id": customer_id}) + if not user: + return [] + + addresses = user.get("addresses", []) + # Sort by created_at desc and return as-is (no _id field) + addresses.sort(key=lambda x: x.get("created_at", datetime.utcnow()), reverse=True) + return addresses + + except Exception as e: + logger.error(f"Error getting embedded addresses for user {customer_id}: {str(e)}") + return [] + + @staticmethod + async def get_address_by_id(customer_id: str, address_id: str) -> Optional[Dict[str, Any]]: + """Get a specific embedded address by ID for a user""" + try: + from app.models.user_model import BookMyServiceUserModel + user = await BookMyServiceUserModel.collection.find_one({"customer_id": customer_id}) + if not user: + return None + + addresses = user.get("addresses", []) + + for a in addresses: + if a.get("address_id") == address_id: + a_copy = dict(a) + # Inject customer_id for backward-compat where used in router + a_copy["customer_id"] = customer_id + return a_copy + return None + + except Exception as e: + logger.error(f"Error getting embedded address {address_id} for user {customer_id}: {str(e)}") + return None + + @staticmethod + async def update_address(customer_id: str, address_id: str, update_data: Dict[str, Any]) -> bool: + """Update an existing embedded address""" + try: + from app.models.user_model import BookMyServiceUserModel + user = await BookMyServiceUserModel.collection.find_one({"customer_id": customer_id}) + if not user: + return False + + addresses = user.get("addresses", []) + + updated = False + for a in addresses: + if a.get("address_id") == address_id: + allowed_fields = [ + "address_line_1", "address_line_2", "city", "state", "postal_code", + "country", "address_type", "is_default", "landmark" + ] + for field in allowed_fields: + if field in update_data: + a[field] = update_data[field] + a["updated_at"] = datetime.utcnow() + updated = True + break + + if not updated: + return False + + # If setting as default, unset other defaults + if update_data.get("is_default"): + for a in addresses: + if a.get("address_id") != address_id and a.get("is_default"): + a["is_default"] = False + a["updated_at"] = datetime.utcnow() + + result = await BookMyServiceUserModel.collection.update_one( + {"customer_id": customer_id}, + {"$set": {"addresses": addresses}} + ) + return result.modified_count > 0 + + except Exception as e: + logger.error(f"Error updating embedded address {address_id} for user {customer_id}: {str(e)}") + return False + + @staticmethod + async def delete_address(customer_id: str, address_id: str) -> bool: + """Delete an embedded address""" + try: + from app.models.user_model import BookMyServiceUserModel + user = await BookMyServiceUserModel.collection.find_one({"customer_id": customer_id}) + if not user: + return False + + addresses = user.get("addresses", []) + # Filter out by new domain id field 'address_id' + new_addresses = [a for a in addresses if a.get("address_id") != address_id] + + result = await BookMyServiceUserModel.collection.update_one( + {"customer_id": customer_id}, + {"$set": {"addresses": new_addresses}} + ) + + logger.info(f"Deleted embedded address {address_id} for user {customer_id}") + return result.modified_count > 0 + + except Exception as e: + logger.error(f"Error deleting embedded address {address_id} for user {customer_id}: {str(e)}") + return False + + @staticmethod + async def get_default_address(customer_id: str) -> Optional[Dict[str, Any]]: + """Get the default embedded address for a user""" + try: + from app.models.user_model import BookMyServiceUserModel + user = await BookMyServiceUserModel.collection.find_one({"customer_id": customer_id}) + if not user: + return None + + addresses = user.get("addresses", []) + for a in addresses: + if a.get("is_default"): + a_copy = dict(a) + a_copy["customer_id"] = customer_id + return a_copy + return None + + except Exception as e: + logger.error(f"Error getting default embedded address for user {customer_id}: {str(e)}") + return None + + @staticmethod + async def set_default_address(customer_id: str, address_id: str) -> bool: + """Set an embedded address as default and unset others""" + try: + from app.models.user_model import BookMyServiceUserModel + user = await BookMyServiceUserModel.collection.find_one({"customer_id": customer_id}) + if not user: + return False + + addresses = user.get("addresses", []) + + changed = False + for a in addresses: + if a.get("address_id") == address_id: + if not a.get("is_default"): + a["is_default"] = True + a["updated_at"] = datetime.utcnow() + changed = True + else: + if a.get("is_default"): + a["is_default"] = False + a["updated_at"] = datetime.utcnow() + changed = True + + if not changed: + # Even if nothing changed, ensure persistence + pass + + result = await BookMyServiceUserModel.collection.update_one( + {"customer_id": customer_id}, + {"$set": {"addresses": addresses}} + ) + return result.modified_count >= 0 + + except Exception as e: + logger.error(f"Error setting default embedded address {address_id} for user {customer_id}: {str(e)}") + return False \ No newline at end of file diff --git a/app/models/favorite_model.py b/app/models/favorite_model.py new file mode 100644 index 0000000000000000000000000000000000000000..195df2576a47f1273134878e9523233c27025933 --- /dev/null +++ b/app/models/favorite_model.py @@ -0,0 +1,209 @@ +from fastapi import HTTPException +from app.core.nosql_client import db +import logging +from datetime import datetime +from typing import List, Optional, Dict, Any +from app.utils.db import prepare_for_db +import uuid +from app.models.user_model import BookMyServiceUserModel + +logger = logging.getLogger("favorite_model") + +class BookMyServiceFavoriteModel: + + @staticmethod + async def create_favorite( + customer_id: str, + favorite_data:dict + ): + """Create a new favorite merchant entry""" + logger.info(f"Creating favorite for customer {customer_id}, merchant {favorite_data['merchant_id']}") + + try: + # Check if favorite already exists + user = await BookMyServiceUserModel.collection.find_one({ + "customer_id": customer_id + }) + + if not user: + logger.error(f"User not found for customer_id {customer_id}") + return None + + favorites = user.get("favorites", []) + + # Create favorite document + favorite_doc = { + "merchant_id":favorite_data["merchant_id"], + "merchant_category": favorite_data["merchant_category"], + "merchant_name":favorite_data["merchant_name"], + "source": favorite_data[ "source"], + "added_at": datetime.utcnow(), + "notes": favorite_data.get("notes") + } + favorites.append(favorite_doc) + sanitized_favorite = prepare_for_db(favorites) + + update_res = await BookMyServiceUserModel.collection.update_one( + {"customer_id": customer_id}, + {"$set": {"favorites": sanitized_favorite}} + ) + if update_res.modified_count > 0: + logger.info(f"favourites created successfully: {favorite_data['merchant_id']} for user: {customer_id}") + else: + logger.info(f"favourites creation attempted with no modified_count for user: {customer_id}") + + return favorite_data["merchant_id"] + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error creating favorite: {str(e)}", exc_info=True) + raise HTTPException(status_code=500, detail="Failed to create favorite") + + @staticmethod + async def delete_favorite(customer_id: str, merchant_id: str): + """Remove a merchant from favorites""" + logger.info(f"Deleting favorite for customer {customer_id}, merchant {merchant_id}") + + try: + user = await BookMyServiceUserModel.collection.find_one({"customer_id": customer_id}) + if not user: + return False + + favorites = user.get("favorites", []) + filterd_favorites = [favorite for favorite in favorites if favorite.get("merchant_id") != merchant_id] + if len(filterd_favorites) == len(favorites): + logger.warning(f"Embedded favorite not found for deletion: {merchant_id} (user {customer_id})") + return False + + # Sanitize for MongoDB before write + sanitized_filterd_favorites = prepare_for_db(filterd_favorites) + result = await BookMyServiceUserModel.collection.update_one( + {"customer_id": customer_id}, + {"$set": {"favorites": sanitized_filterd_favorites}} + ) + + if result.modified_count > 0: + logger.info(f"Embedded favorite deleted successfully: {merchant_id} (user {customer_id})") + return True + else: + logger.info(f"Embedded favorite deletion applied with no modified_count: {merchant_id} (user {customer_id})") + return True + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error deleting favorite: {str(e)}", exc_info=True) + raise HTTPException(status_code=500, detail="Failed to delete favorite") + + @staticmethod + async def get_favorites( customer_id: str, limit:int=50 )-> List[Dict[str, Any]]: + """Get user's favorite merchants, optionally filtered by category""" + logger.info(f"Getting favorite merchant for customer {customer_id}") + + try: + # Build query + + # Check if favorite already exists + user = await BookMyServiceUserModel.collection.find_one({ + "customer_id": customer_id + }) + + if not user: + return [] + + + favorites = user.get("favorites", []) + favorites.sort(key=lambda x: x["added_at"], reverse=True) + favorite_data = favorites[:limit] + + # Get total count for pagination + + logger.info(f"Found {len(favorites)} favorites for customer {customer_id}") + + return { + "favorites": favorite_data, + "total_count": len(favorites), + "limit": limit + } + + except Exception as e: + logger.error(f"Error getting favorites: {str(e)}", exc_info=True) + raise HTTPException(status_code=500, detail="Failed to get favorites") + + @staticmethod + async def get_favorite(customer_id: str, merchant_id: str): + """Get a specific favorite entry""" + logger.info(f"Getting favorite for customer {customer_id}, merchant {merchant_id}") + + try: + + user = await BookMyServiceUserModel.collection.find_one({ + "customer_id": customer_id + }) + + if not user: + return [] + + favorites = user.get("favorites", []) + + + for favorite in favorites: + if favorite.get("merchant_id") == merchant_id: + return favorite + logger.warning(f"Embedded favorite merchant not found: {merchant_id} for user: {customer_id}") + return None + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error getting favorite: {str(e)}", exc_info=True) + raise HTTPException(status_code=500, detail="Failed to get favorite") + + @staticmethod + async def update_favorite_notes(customer_id: str, merchant_id: str, notes: str): + """Update notes for a favorite merchant""" + logger.info(f"Updating notes for customer {customer_id}, merchant {merchant_id}") + + try: + user = await BookMyServiceUserModel.collection.find_one({ + "customer_id": customer_id + }) + + if not user: + return False + + + result = await BookMyServiceUserModel.collection.update_one({"customer_id": customer_id, "favorites.merchant_id": merchant_id}, + {"$set": {"favorites.$.notes": notes}}) + + print("result.matched_count",result.matched_count) + if result.matched_count == 0 : + logger.warning(f"Favorite not found for customer {customer_id}, merchant {merchant_id}") + raise HTTPException(status_code=404, detail="Favorite not found") + + logger.info(f"Favorite merchant notes updated successfully") + return True + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error updating favorite merchant notes: {str(e)}", exc_info=True) + raise HTTPException(status_code=500, detail="Failed to update favorite notes") + + @staticmethod + async def is_favorite(customer_id: str, merchant_id: str) -> bool: + """Check if a merchant is in user's favorites""" + logger.info(f"Checking if merchant {merchant_id} is favorite for customer {customer_id}") + + try: + favorite = await BookMyServiceFavoriteModel.collection.find_one({ + "customer_id": customer_id, + "merchant_id": merchant_id + }) + + return favorite is not None + + except Exception as e: + logger.error(f"Error checking favorite status: {str(e)}", exc_info=True) + return False \ No newline at end of file diff --git a/app/models/guest_model.py b/app/models/guest_model.py new file mode 100644 index 0000000000000000000000000000000000000000..c4ac4669560a1ff7d4c935c0cfebc41d5e086adf --- /dev/null +++ b/app/models/guest_model.py @@ -0,0 +1,290 @@ +from app.core.nosql_client import db +from datetime import datetime +from typing import List, Optional, Dict, Any +import uuid +import logging + +from app.utils.db import prepare_for_db + +logger = logging.getLogger(__name__) + +class GuestModel: + """Model for managing guest profiles embedded under customer documents""" + + @staticmethod + async def create_guest(customer_id: str, guest_data: dict) -> Optional[str]: + """Create a new embedded guest profile under a user in customers collection""" + try: + from app.models.user_model import BookMyServiceUserModel + user = await BookMyServiceUserModel.collection.find_one({"customer_id": customer_id}) + if not user: + logger.error(f"User not found for customer_id {customer_id}") + return None + + guest_id = str(uuid.uuid4()) + current_time = datetime.utcnow() + + guest_doc = { + "guest_id": guest_id, + "first_name": guest_data.get("first_name"), + "last_name": guest_data.get("last_name"), + "email": guest_data.get("email"), + "phone_number": guest_data.get("phone_number"), + "gender": getattr(guest_data.get("gender"), "value", guest_data.get("gender")), + "date_of_birth": guest_data.get("date_of_birth"), + "relationship": getattr(guest_data.get("relationship"), "value", guest_data.get("relationship")), + "notes": guest_data.get("notes"), + "is_default": guest_data.get("is_default", False), + "created_at": current_time, + "updated_at": current_time, + } + + guests = user.get("guests", []) + # Handle default semantics: if setting this guest as default, unset others. + if guest_doc.get("is_default"): + for existing in guests: + if existing.get("is_default"): + existing["is_default"] = False + existing["updated_at"] = current_time + else: + # If this is the first guest, make it default by default + if len(guests) == 0: + guest_doc["is_default"] = True + guests.append(guest_doc) + + # Sanitize for MongoDB (convert date to datetime, strip tzinfo, etc.) + sanitized_guests = prepare_for_db(guests) + update_res = await BookMyServiceUserModel.collection.update_one( + {"customer_id": customer_id}, + {"$set": {"guests": sanitized_guests}} + ) + + if update_res.modified_count > 0: + logger.info(f"Guest created successfully: {guest_id} for user: {customer_id}") + return guest_id + else: + logger.info(f"Guest creation attempted with no modified_count for user: {customer_id}") + return guest_id + + except Exception as e: + logger.error(f"Error creating embedded guest for user {customer_id}: {str(e)}") + return None + + @staticmethod + async def get_user_guests(customer_id: str) -> List[Dict[str, Any]]: + """Get all embedded guests for a specific user from customers collection""" + try: + from app.models.user_model import BookMyServiceUserModel + user = await BookMyServiceUserModel.collection.find_one({"customer_id": customer_id}) + if not user: + return [] + + guests = user.get("guests", []) + guests.sort(key=lambda x: x.get("created_at", datetime.utcnow()), reverse=True) + for g in guests: + g["customer_id"] = customer_id + logger.info(f"Retrieved {len(guests)} embedded guests for user: {customer_id}") + return guests + + except Exception as e: + logger.error(f"Error getting embedded guests for user {customer_id}: {str(e)}") + return [] + + @staticmethod + async def get_guest_by_id(customer_id: str, guest_id: str) -> Optional[Dict[str, Any]]: + """Get a specific embedded guest by ID for a user""" + try: + from app.models.user_model import BookMyServiceUserModel + user = await BookMyServiceUserModel.collection.find_one({"customer_id": customer_id}) + if not user: + return None + + guests = user.get("guests", []) + for g in guests: + if g.get("guest_id") == guest_id: + g_copy = dict(g) + g_copy["customer_id"] = customer_id + logger.info(f"Embedded guest found: {guest_id} for user: {customer_id}") + return g_copy + logger.warning(f"Embedded guest not found: {guest_id} for user: {customer_id}") + return None + + except Exception as e: + logger.error(f"Error getting embedded guest {guest_id} for user {customer_id}: {str(e)}") + return None + + @staticmethod + async def update_guest(customer_id: str, guest_id: str, update_fields: Dict[str, Any]) -> bool: + """Update an embedded guest's information under a user""" + try: + from app.models.user_model import BookMyServiceUserModel + user = await BookMyServiceUserModel.collection.find_one({"customer_id": customer_id}) + if not user: + return False + + guests = user.get("guests", []) + updated = False + + for idx, g in enumerate(guests): + if g.get("guest_id") == guest_id: + normalized_updates: Dict[str, Any] = {} + for k, v in update_fields.items(): + if hasattr(v, "value"): + normalized_updates[k] = v.value + else: + normalized_updates[k] = v + + normalized_updates["updated_at"] = datetime.utcnow() + guests[idx] = {**g, **normalized_updates} + updated = True + # If is_default is being set to True, unset default for others + if update_fields.get("is_default"): + for j, other in enumerate(guests): + if other.get("guest_id") != guest_id and other.get("is_default"): + other["is_default"] = False + other["updated_at"] = datetime.utcnow() + guests[j] = other + break + + if not updated: + logger.warning(f"Embedded guest not found for update: {guest_id} (user {customer_id})") + return False + + # Sanitize for MongoDB before write + sanitized_guests = prepare_for_db(guests) + result = await BookMyServiceUserModel.collection.update_one( + {"customer_id": customer_id}, + {"$set": {"guests": sanitized_guests}} + ) + + if result.modified_count > 0: + logger.info(f"Embedded guest updated successfully: {guest_id} (user {customer_id})") + return True + else: + logger.info(f"Embedded guest update applied with no modified_count: {guest_id} (user {customer_id})") + return True + + except Exception as e: + logger.error(f"Error updating embedded guest {guest_id} for user {customer_id}: {str(e)}") + return False + + @staticmethod + async def delete_guest(customer_id: str, guest_id: str) -> bool: + """Delete an embedded guest profile under a user""" + try: + from app.models.user_model import BookMyServiceUserModel + user = await BookMyServiceUserModel.collection.find_one({"customer_id": customer_id}) + if not user: + return False + + guests = user.get("guests", []) + new_guests = [g for g in guests if g.get("guest_id") != guest_id] + if len(new_guests) == len(guests): + logger.warning(f"Embedded guest not found for deletion: {guest_id} (user {customer_id})") + return False + + # Sanitize for MongoDB before write + sanitized_new_guests = prepare_for_db(new_guests) + result = await BookMyServiceUserModel.collection.update_one( + {"customer_id": customer_id}, + {"$set": {"guests": sanitized_new_guests}} + ) + + if result.modified_count > 0: + logger.info(f"Embedded guest deleted successfully: {guest_id} (user {customer_id})") + return True + else: + logger.info(f"Embedded guest deletion applied with no modified_count: {guest_id} (user {customer_id})") + return True + + except Exception as e: + logger.error(f"Error deleting embedded guest {guest_id} for user {customer_id}: {str(e)}") + return False + + @staticmethod + async def get_guest_count_for_user(customer_id: str) -> int: + """Get the total number of embedded guests for a user""" + try: + from app.models.user_model import BookMyServiceUserModel + user = await BookMyServiceUserModel.collection.find_one({"customer_id": customer_id}) + if not user: + return 0 + return len(user.get("guests", [])) + except Exception as e: + logger.error(f"Error counting embedded guests for user {customer_id}: {str(e)}") + return 0 + + @staticmethod + async def check_guest_ownership(guest_id: str, customer_id: str) -> bool: + """ + Check if a guest belongs to a specific user. + + Args: + guest_id: ID of the guest + customer_id: ID of the user + + Returns: + True if guest belongs to user, False otherwise + """ + try: + from app.models.user_model import BookMyServiceUserModel + user = await BookMyServiceUserModel.collection.find_one({"customer_id": customer_id}) + if not user: + return False + guests = user.get("guests", []) + return any(g.get("guest_id") == guest_id for g in guests) + except Exception as e: + logger.error(f"Error checking embedded guest ownership {guest_id} for user {customer_id}: {str(e)}") + return False + + @staticmethod + async def get_default_guest(customer_id: str) -> Optional[Dict[str, Any]]: + """Get the default guest for a user""" + try: + from app.models.user_model import BookMyServiceUserModel + user = await BookMyServiceUserModel.collection.find_one({"customer_id": customer_id}) + if not user: + return None + guests = user.get("guests", []) + for g in guests: + if g.get("is_default"): + g_copy = dict(g) + g_copy["customer_id"] = customer_id + return g_copy + return None + except Exception as e: + logger.error(f"Error getting default guest for user {customer_id}: {str(e)}") + return None + + @staticmethod + async def set_default_guest(customer_id: str, guest_id: str) -> bool: + """Set a guest as default for a user, unsetting any existing default""" + try: + from app.models.user_model import BookMyServiceUserModel + user = await BookMyServiceUserModel.collection.find_one({"customer_id": customer_id}) + if not user: + return False + guests = user.get("guests", []) + found = False + now = datetime.utcnow() + for g in guests: + if g.get("guest_id") == guest_id: + g["is_default"] = True + g["updated_at"] = now + found = True + else: + if g.get("is_default"): + g["is_default"] = False + g["updated_at"] = now + if not found: + return False + # Sanitize for MongoDB before write + sanitized_guests = prepare_for_db(guests) + res = await BookMyServiceUserModel.collection.update_one( + {"customer_id": customer_id}, + {"$set": {"guests": sanitized_guests}} + ) + return True + except Exception as e: + logger.error(f"Error setting default guest {guest_id} for user {customer_id}: {str(e)}") + return False \ No newline at end of file diff --git a/app/models/otp_model.py b/app/models/otp_model.py new file mode 100644 index 0000000000000000000000000000000000000000..4d4b6c2ad470b7e4953fc97b14c4063bb257a0c8 --- /dev/null +++ b/app/models/otp_model.py @@ -0,0 +1,230 @@ +from fastapi import HTTPException +from app.core.cache_client import get_redis +from app.utils.sms_utils import send_sms_otp +from app.utils.email_utils import send_email_otp +from app.utils.common_utils import is_email +import logging + +logger = logging.getLogger("otp_model") + +class BookMyServiceOTPModel: + OTP_TTL = 300 # 5 minutes + RATE_LIMIT_MAX = 3 + RATE_LIMIT_WINDOW = 600 # 10 minutes + IP_RATE_LIMIT_MAX = 10 # Max 10 OTPs per IP per hour + IP_RATE_LIMIT_WINDOW = 3600 # 1 hour + FAILED_ATTEMPTS_MAX = 5 # Max 5 failed attempts before lock + FAILED_ATTEMPTS_WINDOW = 3600 # 1 hour + ACCOUNT_LOCK_DURATION = 1800 # 30 minutes + + @staticmethod + async def store_otp(identifier: str, phone: str, otp: str, ttl: int = OTP_TTL, client_ip: str = None): + logger.info(f"Storing OTP for identifier: {identifier}, IP: {client_ip}") + + try: + redis = await get_redis() + logger.info(f"Redis connection established for OTP storage") + + # Check if account is locked + if await BookMyServiceOTPModel.is_account_locked(identifier): + logger.warning(f"Account locked for identifier: {identifier}") + raise HTTPException(status_code=423, detail="Account temporarily locked due to too many failed attempts") + + # Rate limit: max 3 OTPs per identifier per 10 minutes + rate_key = f"otp_rate_limit:{identifier}" + logger.info(f"Checking rate limit with key: {rate_key}") + + attempts = await redis.incr(rate_key) + logger.info(f"Current OTP attempts for {identifier}: {attempts}") + + if attempts == 1: + await redis.expire(rate_key, BookMyServiceOTPModel.RATE_LIMIT_WINDOW) + logger.info(f"Set rate limit expiry for {identifier}") + elif attempts > BookMyServiceOTPModel.RATE_LIMIT_MAX: + logger.warning(f"Rate limit exceeded for {identifier}: {attempts} attempts") + raise HTTPException(status_code=429, detail="Too many OTP requests. Try again later.") + + # IP-based rate limiting + if client_ip: + ip_rate_key = f"otp_ip_rate_limit:{client_ip}" + ip_attempts = await redis.incr(ip_rate_key) + + if ip_attempts == 1: + await redis.expire(ip_rate_key, BookMyServiceOTPModel.IP_RATE_LIMIT_WINDOW) + elif ip_attempts > BookMyServiceOTPModel.IP_RATE_LIMIT_MAX: + logger.warning(f"IP rate limit exceeded for {client_ip}: {ip_attempts} attempts") + raise HTTPException(status_code=429, detail="Too many OTP requests from this IP address") + + # Store OTP + otp_key = f"bms_otp:{identifier}" + await redis.setex(otp_key, ttl, otp) + logger.info(f"OTP stored successfully for {identifier} with key: {otp_key}, TTL: {ttl}") + + except HTTPException as e: + logger.error(f"HTTP error storing OTP for {identifier}: {e.status_code} - {e.detail}") + raise e + except Exception as e: + logger.error(f"Unexpected error storing OTP for {identifier}: {str(e)}", exc_info=True) + raise HTTPException(status_code=500, detail="Failed to store OTP") + ''' + # Send OTP via SMS, fallback to Email if identifier is email + try: + sid = send_sms_otp(phone, otp) + print(f"OTP {otp} sent to {phone} via SMS. SID: {sid}") + except Exception as sms_error: + print(f"⚠️ SMS failed: {sms_error}") + if is_email(identifier): + try: + await send_email_otp(identifier, otp) + print(f"✅ OTP {otp} sent to {identifier} via email fallback.") + except Exception as email_error: + raise HTTPException(status_code=500, detail=f"SMS and email both failed: {email_error}") + else: + raise HTTPException(status_code=500, detail="SMS failed and no email fallback available.") + ''' + @staticmethod + async def verify_otp(identifier: str, otp: str, client_ip: str = None): + logger.info(f"Verifying OTP for identifier: {identifier}, IP: {client_ip}") + logger.info(f"Provided OTP: {otp}") + + try: + redis = await get_redis() + logger.info("Redis connection established for OTP verification") + + # Check if account is locked + if await BookMyServiceOTPModel.is_account_locked(identifier): + logger.warning(f"Account locked for identifier: {identifier}") + raise HTTPException(status_code=423, detail="Account temporarily locked due to too many failed attempts") + + key = f"bms_otp:{identifier}" + logger.info(f"Looking up OTP with key: {key}") + + stored = await redis.get(key) + logger.info(f"Stored OTP value: {stored}") + + if stored: + logger.info(f"OTP found in Redis. Comparing: provided='{otp}' vs stored='{stored}'") + if stored == otp: + logger.info(f"OTP verification successful for {identifier}") + await redis.delete(key) + # Clear failed attempts on successful verification + await BookMyServiceOTPModel.clear_failed_attempts(identifier) + logger.info(f"OTP deleted from Redis after successful verification") + return True + else: + logger.warning(f"OTP mismatch for {identifier}: provided='{otp}' vs stored='{stored}'") + # Track failed attempt + await BookMyServiceOTPModel.track_failed_attempt(identifier, client_ip) + return False + else: + logger.warning(f"No OTP found in Redis for identifier: {identifier} with key: {key}") + # Track failed attempt for expired/non-existent OTP + await BookMyServiceOTPModel.track_failed_attempt(identifier, client_ip) + return False + + except HTTPException as e: + logger.error(f"HTTP error verifying OTP for {identifier}: {e.status_code} - {e.detail}") + raise e + except Exception as e: + logger.error(f"Error verifying OTP for {identifier}: {str(e)}", exc_info=True) + return False + + @staticmethod + async def read_otp(identifier: str): + redis = await get_redis() + key = f"bms_otp:{identifier}" + otp = await redis.get(key) + if otp: + return otp + raise HTTPException(status_code=404, detail="OTP not found or expired") + + @staticmethod + async def track_failed_attempt(identifier: str, client_ip: str = None): + """Track failed OTP verification attempts""" + logger.info(f"Tracking failed attempt for identifier: {identifier}, IP: {client_ip}") + + try: + redis = await get_redis() + + # Track failed attempts for identifier + failed_key = f"failed_otp:{identifier}" + attempts = await redis.incr(failed_key) + + if attempts == 1: + await redis.expire(failed_key, BookMyServiceOTPModel.FAILED_ATTEMPTS_WINDOW) + + # Lock account if too many failed attempts + if attempts >= BookMyServiceOTPModel.FAILED_ATTEMPTS_MAX: + await BookMyServiceOTPModel.lock_account(identifier) + logger.warning(f"Account locked for {identifier} after {attempts} failed attempts") + + # Track IP-based failed attempts + if client_ip: + ip_failed_key = f"failed_otp_ip:{client_ip}" + ip_attempts = await redis.incr(ip_failed_key) + + if ip_attempts == 1: + await redis.expire(ip_failed_key, BookMyServiceOTPModel.FAILED_ATTEMPTS_WINDOW) + + logger.info(f"IP {client_ip} failed attempts: {ip_attempts}") + + except Exception as e: + logger.error(f"Error tracking failed attempt for {identifier}: {str(e)}", exc_info=True) + + @staticmethod + async def clear_failed_attempts(identifier: str): + """Clear failed attempts counter on successful verification""" + try: + redis = await get_redis() + failed_key = f"failed_otp:{identifier}" + await redis.delete(failed_key) + logger.info(f"Cleared failed attempts for {identifier}") + except Exception as e: + logger.error(f"Error clearing failed attempts for {identifier}: {str(e)}", exc_info=True) + + @staticmethod + async def lock_account(identifier: str): + """Lock account temporarily""" + try: + redis = await get_redis() + lock_key = f"account_locked:{identifier}" + await redis.setex(lock_key, BookMyServiceOTPModel.ACCOUNT_LOCK_DURATION, "locked") + logger.info(f"Account locked for {identifier} for {BookMyServiceOTPModel.ACCOUNT_LOCK_DURATION} seconds") + except Exception as e: + logger.error(f"Error locking account for {identifier}: {str(e)}", exc_info=True) + + @staticmethod + async def is_account_locked(identifier: str) -> bool: + """Check if account is currently locked""" + try: + redis = await get_redis() + lock_key = f"account_locked:{identifier}" + locked = await redis.get(lock_key) + return locked is not None + except Exception as e: + logger.error(f"Error checking account lock for {identifier}: {str(e)}", exc_info=True) + return False + + @staticmethod + async def get_rate_limit_count(rate_key: str) -> int: + """Get current rate limit count for a key""" + try: + redis = await get_redis() + count = await redis.get(rate_key) + return int(count) if count else 0 + except Exception as e: + logger.error(f"Error getting rate limit count for {rate_key}: {str(e)}", exc_info=True) + return 0 + + @staticmethod + async def increment_rate_limit(rate_key: str, window: int) -> int: + """Increment rate limit counter with expiry""" + try: + redis = await get_redis() + count = await redis.incr(rate_key) + if count == 1: + await redis.expire(rate_key, window) + return count + except Exception as e: + logger.error(f"Error incrementing rate limit for {rate_key}: {str(e)}", exc_info=True) + return 0 \ No newline at end of file diff --git a/app/models/pet_model.py b/app/models/pet_model.py new file mode 100644 index 0000000000000000000000000000000000000000..ac73ce7d9e7a2c5d9e941e25bd49f897c4505599 --- /dev/null +++ b/app/models/pet_model.py @@ -0,0 +1,270 @@ +from app.core.nosql_client import db +from datetime import datetime +from typing import List, Optional, Dict, Any +import uuid +import logging + +from app.utils.db import prepare_for_db + +logger = logging.getLogger(__name__) + +class PetModel: + """Model for managing pet profiles embedded under customer documents""" + + @staticmethod + async def create_pet( + customer_id: str, + pet_data: dict + ) -> Optional[str]: + """Create a new embedded pet profile under a user in customers collection""" + try: + from app.models.user_model import BookMyServiceUserModel + user = await BookMyServiceUserModel.collection.find_one({"customer_id": customer_id}) + if not user: + logger.error(f"User not found for customer_id {customer_id}") + return None + + pet_id = str(uuid.uuid4()) + current_time = datetime.utcnow() + + pet_doc = { + "pet_id": pet_id, + "pet_name": pet_data.get('pet_name'), + "species": getattr(pet_data.get('species'), 'value', pet_data.get('species')), + "breed": pet_data.get('breed'), + "date_of_birth": pet_data.get('date_of_birth'), + "age": pet_data.get('age'), + "weight": pet_data.get('weight'), + "gender": getattr(pet_data.get('gender'), 'value', pet_data.get('gender')), + "temperament": getattr(pet_data.get('temperament'), 'value', pet_data.get('temperament')), + "health_notes": pet_data.get('health_notes'), + "is_vaccinated": pet_data.get('is_vaccinated'), + "pet_photo_url": pet_data.get('pet_photo_url'), + "is_default": pet_data.get('is_default', False), + "created_at": current_time, + "updated_at": current_time + } + + pets = user.get("pets", []) + if pet_doc.get("is_default"): + for existing in pets: + if existing.get("is_default"): + existing["is_default"] = False + existing["updated_at"] = current_time + else: + if len(pets) == 0: + pet_doc["is_default"] = True + pets.append(pet_doc) + + # Sanitize pets list for MongoDB (convert date to datetime, etc.) + sanitized_pets = prepare_for_db(pets) + update_res = await BookMyServiceUserModel.collection.update_one( + {"customer_id": customer_id}, + {"$set": {"pets": sanitized_pets}} + ) + + if update_res.modified_count > 0: + logger.info(f"Embedded pet created successfully: {pet_id} for user: {customer_id}") + return pet_id + else: + logger.info(f"Embedded pet creation attempted with no modified_count for user: {customer_id}") + return pet_id + + except Exception as e: + logger.error(f"Error creating embedded pet for user {customer_id}: {str(e)}") + return None + + @staticmethod + async def get_user_pets(customer_id: str) -> List[Dict[str, Any]]: + """Get all embedded pets for a specific user from customers collection""" + try: + from app.models.user_model import BookMyServiceUserModel + user = await BookMyServiceUserModel.collection.find_one({"customer_id": customer_id}) + if not user: + return [] + + pets = user.get("pets", []) + pets.sort(key=lambda x: x.get("created_at", datetime.utcnow()), reverse=True) + for p in pets: + p["customer_id"] = customer_id + logger.info(f"Retrieved {len(pets)} embedded pets for user: {customer_id}") + return pets + + except Exception as e: + logger.error(f"Error getting embedded pets for user {customer_id}: {str(e)}") + return [] + + @staticmethod + async def get_pet_by_id(customer_id: str, pet_id: str) -> Optional[Dict[str, Any]]: + """Get a specific embedded pet by ID for a user""" + try: + from app.models.user_model import BookMyServiceUserModel + user = await BookMyServiceUserModel.collection.find_one({"customer_id": customer_id}) + if not user: + return None + + pets = user.get("pets", []) + for p in pets: + if p.get("pet_id") == pet_id: + p_copy = dict(p) + p_copy["customer_id"] = customer_id + logger.info(f"Embedded pet found: {pet_id} for user: {customer_id}") + return p_copy + logger.warning(f"Embedded pet not found: {pet_id} for user: {customer_id}") + return None + + except Exception as e: + logger.error(f"Error getting embedded pet {pet_id} for user {customer_id}: {str(e)}") + return None + + @staticmethod + async def update_pet(customer_id: str, pet_id: str, update_fields: Dict[str, Any]) -> bool: + """Update an embedded pet's information under a user""" + try: + from app.models.user_model import BookMyServiceUserModel + user = await BookMyServiceUserModel.collection.find_one({"customer_id": customer_id}) + if not user: + return False + + pets = user.get("pets", []) + updated = False + + for idx, p in enumerate(pets): + if p.get("pet_id") == pet_id: + normalized_updates: Dict[str, Any] = {} + for k, v in update_fields.items(): + if hasattr(v, "value"): + normalized_updates[k] = v.value + else: + normalized_updates[k] = v + + normalized_updates["updated_at"] = datetime.utcnow() + pets[idx] = {**p, **normalized_updates} + updated = True + if update_fields.get("is_default"): + for j, other in enumerate(pets): + if other.get("pet_id") != pet_id and other.get("is_default"): + other["is_default"] = False + other["updated_at"] = datetime.utcnow() + pets[j] = other + break + + if not updated: + logger.warning(f"Embedded pet not found for update: {pet_id} (user {customer_id})") + return False + + # Sanitize pets list for MongoDB + sanitized_pets = prepare_for_db(pets) + result = await BookMyServiceUserModel.collection.update_one( + {"customer_id": customer_id}, + {"$set": {"pets": sanitized_pets}} + ) + + if result.modified_count > 0: + logger.info(f"Embedded pet updated successfully: {pet_id} (user {customer_id})") + return True + else: + logger.info(f"Embedded pet update applied with no modified_count: {pet_id} (user {customer_id})") + return True + + except Exception as e: + logger.error(f"Error updating embedded pet {pet_id} for user {customer_id}: {str(e)}") + return False + + @staticmethod + async def delete_pet(customer_id: str, pet_id: str) -> bool: + """Delete an embedded pet profile under a user""" + try: + from app.models.user_model import BookMyServiceUserModel + user = await BookMyServiceUserModel.collection.find_one({"customer_id": customer_id}) + if not user: + return False + + pets = user.get("pets", []) + new_pets = [p for p in pets if p.get("pet_id") != pet_id] + if len(new_pets) == len(pets): + logger.warning(f"Embedded pet not found for deletion: {pet_id} (user {customer_id})") + return False + + # Sanitize pets list for MongoDB + sanitized_new_pets = prepare_for_db(new_pets) + result = await BookMyServiceUserModel.collection.update_one( + {"customer_id": customer_id}, + {"$set": {"pets": sanitized_new_pets}} + ) + + if result.modified_count > 0: + logger.info(f"Embedded pet deleted successfully: {pet_id} (user {customer_id})") + return True + else: + logger.info(f"Embedded pet deletion applied with no modified_count: {pet_id} (user {customer_id})") + return True + + except Exception as e: + logger.error(f"Error deleting embedded pet {pet_id} for user {customer_id}: {str(e)}") + return False + + @staticmethod + async def get_pet_count_for_user(customer_id: str) -> int: + """Get the total number of embedded pets for a user""" + try: + from app.models.user_model import BookMyServiceUserModel + user = await BookMyServiceUserModel.collection.find_one({"customer_id": customer_id}) + if not user: + return 0 + return len(user.get("pets", [])) + except Exception as e: + logger.error(f"Error counting embedded pets for user {customer_id}: {str(e)}") + return 0 + + @staticmethod + async def get_default_pet(customer_id: str) -> Optional[Dict[str, Any]]: + """Get the default pet for a user""" + try: + from app.models.user_model import BookMyServiceUserModel + user = await BookMyServiceUserModel.collection.find_one({"customer_id": customer_id}) + if not user: + return None + pets = user.get("pets", []) + for p in pets: + if p.get("is_default"): + p_copy = dict(p) + p_copy["customer_id"] = customer_id + return p_copy + return None + except Exception as e: + logger.error(f"Error getting default pet for user {customer_id}: {str(e)}") + return None + + @staticmethod + async def set_default_pet(customer_id: str, pet_id: str) -> bool: + """Set a pet as default for a user, unsetting any existing default""" + try: + from app.models.user_model import BookMyServiceUserModel + user = await BookMyServiceUserModel.collection.find_one({"customer_id": customer_id}) + if not user: + return False + pets = user.get("pets", []) + found = False + now = datetime.utcnow() + for p in pets: + if p.get("pet_id") == pet_id: + p["is_default"] = True + p["updated_at"] = now + found = True + else: + if p.get("is_default"): + p["is_default"] = False + p["updated_at"] = now + if not found: + return False + # Sanitize pets list for MongoDB + sanitized_pets = prepare_for_db(pets) + res = await BookMyServiceUserModel.collection.update_one( + {"customer_id": customer_id}, + {"$set": {"pets": sanitized_pets}} + ) + return True + except Exception as e: + logger.error(f"Error setting default pet {pet_id} for user {customer_id}: {str(e)}") + return False \ No newline at end of file diff --git a/app/models/refresh_token_model.py b/app/models/refresh_token_model.py new file mode 100644 index 0000000000000000000000000000000000000000..675fd2fd7454465150e7272805b8a54d96a647b3 --- /dev/null +++ b/app/models/refresh_token_model.py @@ -0,0 +1,282 @@ +from datetime import datetime, timedelta +from typing import Optional +import uuid +import logging +from app.core.nosql_client import db +from app.core.cache_client import get_redis + +logger = logging.getLogger("refresh_token_model") + +class RefreshTokenModel: + """Model for managing refresh tokens with rotation support""" + + collection = db["refresh_tokens"] + + # Token family tracking for rotation + TOKEN_FAMILY_TTL = 30 * 24 * 3600 # 30 days in seconds + + @staticmethod + async def create_token_family(customer_id: str, device_info: Optional[str] = None) -> str: + """Create a new token family for refresh token rotation""" + family_id = str(uuid.uuid4()) + + try: + redis = await get_redis() + family_key = f"token_family:{family_id}" + + family_data = { + "customer_id": customer_id, + "device_info": device_info, + "created_at": datetime.utcnow().isoformat(), + "rotation_count": 0 + } + + import json + await redis.setex(family_key, RefreshTokenModel.TOKEN_FAMILY_TTL, json.dumps(family_data)) + + logger.info(f"Created token family {family_id} for user {customer_id}") + return family_id + + except Exception as e: + logger.error(f"Error creating token family: {str(e)}", exc_info=True) + raise + + @staticmethod + async def get_token_family(family_id: str) -> Optional[dict]: + """Get token family data""" + try: + redis = await get_redis() + family_key = f"token_family:{family_id}" + + data = await redis.get(family_key) + if data: + import json + return json.loads(data) + return None + + except Exception as e: + logger.error(f"Error getting token family: {str(e)}", exc_info=True) + return None + + @staticmethod + async def increment_rotation_count(family_id: str) -> int: + """Increment rotation count for a token family""" + try: + redis = await get_redis() + family_key = f"token_family:{family_id}" + + family_data = await RefreshTokenModel.get_token_family(family_id) + if not family_data: + return 0 + + family_data["rotation_count"] = family_data.get("rotation_count", 0) + 1 + family_data["last_rotated"] = datetime.utcnow().isoformat() + + import json + ttl = await redis.ttl(family_key) + if ttl > 0: + await redis.setex(family_key, ttl, json.dumps(family_data)) + + logger.info(f"Incremented rotation count for family {family_id} to {family_data['rotation_count']}") + return family_data["rotation_count"] + + except Exception as e: + logger.error(f"Error incrementing rotation count: {str(e)}", exc_info=True) + return 0 + + @staticmethod + async def store_refresh_token( + token_id: str, + customer_id: str, + family_id: str, + expires_at: datetime, + remember_me: bool = False, + device_info: Optional[str] = None, + ip_address: Optional[str] = None + ): + """Store refresh token metadata""" + try: + token_doc = { + "token_id": token_id, + "customer_id": customer_id, + "family_id": family_id, + "expires_at": expires_at, + "remember_me": remember_me, + "device_info": device_info, + "ip_address": ip_address, + "created_at": datetime.utcnow(), + "revoked": False, + "used": False + } + + await RefreshTokenModel.collection.insert_one(token_doc) + logger.info(f"Stored refresh token {token_id} for user {customer_id}") + + except Exception as e: + logger.error(f"Error storing refresh token: {str(e)}", exc_info=True) + raise + + @staticmethod + async def mark_token_as_used(token_id: str) -> bool: + """Mark a refresh token as used (for rotation)""" + try: + result = await RefreshTokenModel.collection.update_one( + {"token_id": token_id}, + { + "$set": { + "used": True, + "used_at": datetime.utcnow() + } + } + ) + + if result.modified_count > 0: + logger.info(f"Marked token {token_id} as used") + return True + return False + + except Exception as e: + logger.error(f"Error marking token as used: {str(e)}", exc_info=True) + return False + + @staticmethod + async def is_token_valid(token_id: str) -> bool: + """Check if a refresh token is valid (not revoked or used)""" + try: + token = await RefreshTokenModel.collection.find_one({"token_id": token_id}) + + if not token: + logger.warning(f"Token {token_id} not found") + return False + + if token.get("revoked"): + logger.warning(f"Token {token_id} is revoked") + return False + + if token.get("used"): + logger.warning(f"Token {token_id} already used - possible replay attack") + # Revoke entire token family on reuse attempt + await RefreshTokenModel.revoke_token_family(token.get("family_id")) + return False + + if token.get("expires_at") < datetime.utcnow(): + logger.warning(f"Token {token_id} is expired") + return False + + return True + + except Exception as e: + logger.error(f"Error checking token validity: {str(e)}", exc_info=True) + return False + + @staticmethod + async def get_token_metadata(token_id: str) -> Optional[dict]: + """Get refresh token metadata""" + try: + token = await RefreshTokenModel.collection.find_one({"token_id": token_id}) + return token + except Exception as e: + logger.error(f"Error getting token metadata: {str(e)}", exc_info=True) + return None + + @staticmethod + async def revoke_token(token_id: str) -> bool: + """Revoke a specific refresh token""" + try: + result = await RefreshTokenModel.collection.update_one( + {"token_id": token_id}, + { + "$set": { + "revoked": True, + "revoked_at": datetime.utcnow() + } + } + ) + + if result.modified_count > 0: + logger.info(f"Revoked token {token_id}") + return True + return False + + except Exception as e: + logger.error(f"Error revoking token: {str(e)}", exc_info=True) + return False + + @staticmethod + async def revoke_token_family(family_id: str) -> int: + """Revoke all tokens in a family (security breach detection)""" + try: + result = await RefreshTokenModel.collection.update_many( + {"family_id": family_id, "revoked": False}, + { + "$set": { + "revoked": True, + "revoked_at": datetime.utcnow(), + "revoke_reason": "token_reuse_detected" + } + } + ) + + # Also delete the family from Redis + redis = await get_redis() + await redis.delete(f"token_family:{family_id}") + + logger.warning(f"Revoked {result.modified_count} tokens in family {family_id}") + return result.modified_count + + except Exception as e: + logger.error(f"Error revoking token family: {str(e)}", exc_info=True) + return 0 + + @staticmethod + async def revoke_all_user_tokens(customer_id: str) -> int: + """Revoke all refresh tokens for a user (logout from all devices)""" + try: + result = await RefreshTokenModel.collection.update_many( + {"customer_id": customer_id, "revoked": False}, + { + "$set": { + "revoked": True, + "revoked_at": datetime.utcnow(), + "revoke_reason": "user_logout_all" + } + } + ) + + logger.info(f"Revoked {result.modified_count} tokens for user {customer_id}") + return result.modified_count + + except Exception as e: + logger.error(f"Error revoking all user tokens: {str(e)}", exc_info=True) + return 0 + + @staticmethod + async def get_active_sessions(customer_id: str) -> list: + """Get all active sessions (valid refresh tokens) for a user""" + try: + tokens = await RefreshTokenModel.collection.find({ + "customer_id": customer_id, + "revoked": False, + "expires_at": {"$gt": datetime.utcnow()} + }).to_list(length=100) + + return tokens + + except Exception as e: + logger.error(f"Error getting active sessions: {str(e)}", exc_info=True) + return [] + + @staticmethod + async def cleanup_expired_tokens(): + """Cleanup expired tokens (run periodically)""" + try: + result = await RefreshTokenModel.collection.delete_many({ + "expires_at": {"$lt": datetime.utcnow()} + }) + + logger.info(f"Cleaned up {result.deleted_count} expired tokens") + return result.deleted_count + + except Exception as e: + logger.error(f"Error cleaning up expired tokens: {str(e)}", exc_info=True) + return 0 diff --git a/app/models/review_model.py b/app/models/review_model.py new file mode 100644 index 0000000000000000000000000000000000000000..ba55e08ed62dc91cfee03f0b8ec394f3b2edbbc3 --- /dev/null +++ b/app/models/review_model.py @@ -0,0 +1,45 @@ +from fastapi import HTTPException +from app.core.nosql_client import db +import logging +from datetime import datetime + +from app.utils.db import prepare_for_db + +logger = logging.getLogger(__name__) + +class ReviewModel: + collection = db["merchant_reviews"] + + @staticmethod + async def create_review( + review_data:dict + )->dict: + logger.info(f"Creating review for merchant {review_data['merchant_id']}") + + try: + + # Creating review document + review_doc = { + "merchant_id":review_data["merchant_id"], + "location_id": review_data["location_id"], + "user_name":review_data["user_name"], + "rating": review_data[ "rating"], + "review_text":review_data["review_text"], + "review_date": datetime.utcnow(), + "verified_purchase":review_data["verified_purchase"] + } + + sanitized_review = prepare_for_db(review_doc) + + result = await ReviewModel.collection.insert_one(sanitized_review) + + logger.info(f"review data inserted successfully: {result}") + + if result.inserted_id: + return review_doc + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error while adding review: {str(e)}", exc_info=True) + raise HTTPException(status_code=500, detail="Failed to add review details") \ No newline at end of file diff --git a/app/models/social_account_model.py b/app/models/social_account_model.py new file mode 100644 index 0000000000000000000000000000000000000000..3cb9302b440ed9b43cf3a20729adde7f9f96f96f --- /dev/null +++ b/app/models/social_account_model.py @@ -0,0 +1,257 @@ +from fastapi import HTTPException +from app.core.nosql_client import db +from datetime import datetime +from typing import Optional, List, Dict, Any +import logging + +logger = logging.getLogger("social_account_model") + +class SocialAccountModel: + """Model for managing social login accounts and linking""" + + collection = db["social_accounts"] + + @staticmethod + async def create_social_account(customer_id: str, provider: str, provider_customer_id: str, user_info: Dict[str, Any]) -> str: + """Create a new social account record""" + try: + social_account = { + "customer_id": customer_id, + "provider": provider, + "provider_customer_id": provider_customer_id, + "email": user_info.get("email"), + "name": user_info.get("name"), + "picture": user_info.get("picture"), + "profile_data": user_info, + "created_at": datetime.utcnow(), + "updated_at": datetime.utcnow(), + "is_active": True, + "last_login": datetime.utcnow() + } + + result = await SocialAccountModel.collection.insert_one(social_account) + logger.info(f"Created social account for user {customer_id} with provider {provider}") + return str(result.inserted_id) + + except Exception as e: + logger.error(f"Error creating social account: {str(e)}", exc_info=True) + raise HTTPException(status_code=500, detail="Failed to create social account") + + @staticmethod + async def find_by_provider_and_customer_id(provider: str, provider_customer_id: str) -> Optional[Dict[str, Any]]: + """Find social account by provider and provider user ID""" + try: + account = await SocialAccountModel.collection.find_one({ + "provider": provider, + "provider_customer_id": provider_customer_id, + "is_active": True + }) + return account + except Exception as e: + logger.error(f"Error finding social account: {str(e)}", exc_info=True) + return None + + @staticmethod + async def find_by_customer_id(customer_id: str) -> List[Dict[str, Any]]: + """Find all social accounts for a user""" + try: + cursor = SocialAccountModel.collection.find({ + "customer_id": customer_id, + "is_active": True + }) + accounts = await cursor.to_list(length=None) + return accounts + except Exception as e: + logger.error(f"Error finding social accounts for user {customer_id}: {str(e)}", exc_info=True) + return [] + + @staticmethod + async def update_social_account(provider: str, provider_customer_id: str, user_info: Dict[str, Any]) -> bool: + """Update social account with latest user info""" + try: + update_data = { + "email": user_info.get("email"), + "name": user_info.get("name"), + "picture": user_info.get("picture"), + "profile_data": user_info, + "updated_at": datetime.utcnow(), + "last_login": datetime.utcnow() + } + + result = await SocialAccountModel.collection.update_one( + { + "provider": provider, + "provider_customer_id": provider_customer_id, + "is_active": True + }, + {"$set": update_data} + ) + + return result.modified_count > 0 + + except Exception as e: + logger.error(f"Error updating social account: {str(e)}", exc_info=True) + return False + + @staticmethod + async def link_social_account(customer_id: str, provider: str, provider_customer_id: str, user_info: Dict[str, Any]) -> bool: + """Link a social account to an existing user""" + try: + # Check if this social account is already linked to another user + existing_account = await SocialAccountModel.find_by_provider_and_customer_id(provider, provider_customer_id) + + if existing_account and existing_account["customer_id"] != customer_id: + logger.warning(f"Social account {provider}:{provider_customer_id} already linked to user {existing_account['customer_id']}") + raise HTTPException( + status_code=409, + detail=f"This {provider} account is already linked to another user" + ) + + if existing_account and existing_account["customer_id"] == customer_id: + # Update existing account + await SocialAccountModel.update_social_account(provider, provider_customer_id, user_info) + return True + + # Create new social account link + await SocialAccountModel.create_social_account(customer_id, provider, provider_customer_id, user_info) + return True + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error linking social account: {str(e)}", exc_info=True) + raise HTTPException(status_code=500, detail="Failed to link social account") + + @staticmethod + async def unlink_social_account(customer_id: str, provider: str) -> bool: + """Unlink a social account from a user""" + try: + result = await SocialAccountModel.collection.update_one( + { + "customer_id": customer_id, + "provider": provider, + "is_active": True + }, + { + "$set": { + "is_active": False, + "updated_at": datetime.utcnow() + } + } + ) + + if result.modified_count > 0: + logger.info(f"Unlinked {provider} account for user {customer_id}") + return True + else: + logger.warning(f"No active {provider} account found for user {customer_id}") + return False + + except Exception as e: + logger.error(f"Error unlinking social account: {str(e)}", exc_info=True) + return False + + @staticmethod + async def get_profile_picture(customer_id: str, preferred_provider: str = None) -> Optional[str]: + """Get user's profile picture from social accounts""" + try: + query = {"customer_id": customer_id, "is_active": True} + + # If preferred provider specified, try that first + if preferred_provider: + account = await SocialAccountModel.collection.find_one({ + **query, + "provider": preferred_provider, + "picture": {"$exists": True, "$ne": None} + }) + if account and account.get("picture"): + return account["picture"] + + # Otherwise, get any account with a profile picture + account = await SocialAccountModel.collection.find_one({ + **query, + "picture": {"$exists": True, "$ne": None} + }) + + return account.get("picture") if account else None + + except Exception as e: + logger.error(f"Error getting profile picture for user {customer_id}: {str(e)}", exc_info=True) + return None + + @staticmethod + async def get_social_account_summary(customer_id: str) -> Dict[str, Any]: + """Get summary of all linked social accounts for a user""" + try: + accounts = await SocialAccountModel.find_by_customer_id(customer_id) + + summary = { + "linked_accounts": [], + "total_accounts": len(accounts), + "profile_picture": None + } + + for account in accounts: + summary["linked_accounts"].append({ + "provider": account["provider"], + "email": account.get("email"), + "name": account.get("name"), + "linked_at": account["created_at"], + "last_login": account.get("last_login") + }) + + # Set profile picture if available + if not summary["profile_picture"] and account.get("picture"): + summary["profile_picture"] = account["picture"] + + return summary + + except Exception as e: + logger.error(f"Error getting social account summary for user {customer_id}: {str(e)}", exc_info=True) + return {"linked_accounts": [], "total_accounts": 0, "profile_picture": None} + + @staticmethod + async def merge_social_accounts(primary_customer_id: str, secondary_customer_id: str) -> bool: + """Merge social accounts from secondary user to primary user""" + try: + # Get all social accounts from secondary user + secondary_accounts = await SocialAccountModel.find_by_customer_id(secondary_customer_id) + + for account in secondary_accounts: + # Check if primary user already has this provider linked + existing = await SocialAccountModel.collection.find_one({ + "customer_id": primary_customer_id, + "provider": account["provider"], + "is_active": True + }) + + if not existing: + # Transfer the account to primary user + await SocialAccountModel.collection.update_one( + {"_id": account["_id"]}, + { + "$set": { + "customer_id": primary_customer_id, + "updated_at": datetime.utcnow() + } + } + ) + logger.info(f"Transferred {account['provider']} account from user {secondary_customer_id} to {primary_customer_id}") + else: + # Deactivate the secondary account + await SocialAccountModel.collection.update_one( + {"_id": account["_id"]}, + { + "$set": { + "is_active": False, + "updated_at": datetime.utcnow() + } + } + ) + logger.info(f"Deactivated duplicate {account['provider']} account for user {secondary_customer_id}") + + return True + + except Exception as e: + logger.error(f"Error merging social accounts: {str(e)}", exc_info=True) + return False \ No newline at end of file diff --git a/app/models/social_security_model.py b/app/models/social_security_model.py new file mode 100644 index 0000000000000000000000000000000000000000..97f48f526da3bb1fdfafa6ac16633c8bd1f211ec --- /dev/null +++ b/app/models/social_security_model.py @@ -0,0 +1,188 @@ +from datetime import datetime, timedelta +import logging +from app.core.cache_client import get_redis +from fastapi import HTTPException + +logger = logging.getLogger(__name__) + +class SocialSecurityModel: + """Model for handling social login security features""" + + # Rate limiting constants + OAUTH_RATE_LIMIT_MAX = 5 # Max OAuth attempts per IP per hour + OAUTH_RATE_LIMIT_WINDOW = 3600 # 1 hour in seconds + + # Failed attempt tracking + OAUTH_FAILED_ATTEMPTS_MAX = 3 # Max failed OAuth attempts per IP + OAUTH_FAILED_ATTEMPTS_WINDOW = 1800 # 30 minutes + OAUTH_IP_LOCK_DURATION = 3600 # 1 hour lock for IP + + @staticmethod + async def check_oauth_rate_limit(client_ip: str, provider: str) -> bool: + """Check if OAuth rate limit is exceeded for IP and provider""" + if not client_ip: + return True # Allow if no IP provided + + try: + redis = await get_redis() + rate_key = f"oauth_rate:{client_ip}:{provider}" + + current_count = await redis.get(rate_key) + if current_count and int(current_count) >= SocialSecurityModel.OAUTH_RATE_LIMIT_MAX: + logger.warning(f"OAuth rate limit exceeded for IP {client_ip} and provider {provider}") + return False + + return True + + except Exception as e: + logger.error(f"Error checking OAuth rate limit: {str(e)}", exc_info=True) + return True # Allow on error to avoid blocking legitimate users + + @staticmethod + async def increment_oauth_rate_limit(client_ip: str, provider: str): + """Increment OAuth rate limit counter""" + if not client_ip: + return + + try: + redis = await get_redis() + rate_key = f"oauth_rate:{client_ip}:{provider}" + + count = await redis.incr(rate_key) + if count == 1: + await redis.expire(rate_key, SocialSecurityModel.OAUTH_RATE_LIMIT_WINDOW) + + logger.info(f"OAuth rate limit count for {client_ip}:{provider} = {count}") + + except Exception as e: + logger.error(f"Error incrementing OAuth rate limit: {str(e)}", exc_info=True) + + @staticmethod + async def track_oauth_failed_attempt(client_ip: str, provider: str): + """Track failed OAuth verification attempts""" + if not client_ip: + return + + try: + redis = await get_redis() + failed_key = f"oauth_failed:{client_ip}:{provider}" + + attempts = await redis.incr(failed_key) + if attempts == 1: + await redis.expire(failed_key, SocialSecurityModel.OAUTH_FAILED_ATTEMPTS_WINDOW) + + # Lock IP if too many failed attempts + if attempts >= SocialSecurityModel.OAUTH_FAILED_ATTEMPTS_MAX: + await SocialSecurityModel.lock_oauth_ip(client_ip, provider) + logger.warning(f"IP {client_ip} locked for provider {provider} after {attempts} failed attempts") + + logger.info(f"OAuth failed attempts for {client_ip}:{provider} = {attempts}") + + except Exception as e: + logger.error(f"Error tracking OAuth failed attempt: {str(e)}", exc_info=True) + + @staticmethod + async def lock_oauth_ip(client_ip: str, provider: str): + """Lock IP for OAuth attempts on specific provider""" + try: + redis = await get_redis() + lock_key = f"oauth_ip_locked:{client_ip}:{provider}" + await redis.setex(lock_key, SocialSecurityModel.OAUTH_IP_LOCK_DURATION, "locked") + logger.info(f"IP {client_ip} locked for OAuth provider {provider}") + except Exception as e: + logger.error(f"Error locking OAuth IP: {str(e)}", exc_info=True) + + @staticmethod + async def is_oauth_ip_locked(client_ip: str, provider: str) -> bool: + """Check if IP is locked for OAuth attempts on specific provider""" + if not client_ip: + return False + + try: + redis = await get_redis() + lock_key = f"oauth_ip_locked:{client_ip}:{provider}" + locked = await redis.get(lock_key) + return locked is not None + except Exception as e: + logger.error(f"Error checking OAuth IP lock: {str(e)}", exc_info=True) + return False + + @staticmethod + async def clear_oauth_failed_attempts(client_ip: str, provider: str): + """Clear failed OAuth attempts on successful verification""" + if not client_ip: + return + + try: + redis = await get_redis() + failed_key = f"oauth_failed:{client_ip}:{provider}" + await redis.delete(failed_key) + logger.info(f"Cleared OAuth failed attempts for {client_ip}:{provider}") + except Exception as e: + logger.error(f"Error clearing OAuth failed attempts: {str(e)}", exc_info=True) + + @staticmethod + async def validate_oauth_token_format(token: str, provider: str) -> bool: + """Basic validation of OAuth token format""" + # In local test mode, accept any non-empty string to facilitate testing + try: + from app.core.config import settings + if getattr(settings, "OAUTH_TEST_MODE", False): + return bool(token) + except Exception: + pass + + if not token or not isinstance(token, str): + return False + + # Basic length and format checks + if provider == "google": + # Normalize optional Bearer prefix + t = token.strip() + if t.lower().startswith("bearer "): + t = t[7:] + # Accept Google ID tokens (JWT) + if len(t) > 100 and t.count('.') == 2: + return True + # Accept Google OAuth access tokens (commonly start with 'ya29.') and are shorter + if t.startswith("ya29.") or (len(t) >= 20 and len(t) <= 4096 and t.count('.') < 2): + return True + return False + elif provider == "apple": + # Apple ID tokens are also JWT format + return len(token) > 100 and token.count('.') == 2 + elif provider == "facebook": + # Facebook access tokens are typically shorter + return len(token) > 20 and len(token) < 500 + + return True # Allow unknown providers + + @staticmethod + async def log_oauth_attempt(client_ip: str, provider: str, success: bool, customer_id: str = None): + """Log OAuth authentication attempts for security monitoring""" + try: + redis = await get_redis() + log_key = f"oauth_log:{datetime.utcnow().strftime('%Y-%m-%d')}" + + log_entry = { + "timestamp": datetime.utcnow().isoformat(), + "ip": client_ip, + "provider": provider, + "success": success, + "customer_id": customer_id + } + + # Store as JSON string in Redis list + import json + await redis.lpush(log_key, json.dumps(log_entry)) + + # Keep only last 1000 entries per day + await redis.ltrim(log_key, 0, 999) + + # Set expiry for 30 days + await redis.expire(log_key, 30 * 24 * 3600) + + logger.info(f"OAuth attempt logged: {provider} from {client_ip} - {'success' if success else 'failed'}") + + except Exception as e: + logger.error(f"Error logging OAuth attempt: {str(e)}", exc_info=True) \ No newline at end of file diff --git a/app/models/user_model.py b/app/models/user_model.py new file mode 100644 index 0000000000000000000000000000000000000000..f3cac894597bd890d6d7abe4acd554e1c90e4bf9 --- /dev/null +++ b/app/models/user_model.py @@ -0,0 +1,159 @@ +from fastapi import HTTPException +from app.core.nosql_client import db +from app.utils.common_utils import is_email, is_phone, validate_identifier # Updated imports +from app.schemas.user_schema import UserRegisterRequest +import logging + +logger = logging.getLogger("user_model") + +class BookMyServiceUserModel: + collection = db["customers"] + + @staticmethod + async def find_by_email(email: str): + logger.info(f"Searching for user by email: {email}") + try: + user = await BookMyServiceUserModel.collection.find_one({"email": email}) + if user: + logger.info(f"User found by email: {email}") + else: + logger.info(f"No user found with email: {email}") + return user + except Exception as e: + logger.error(f"Error finding user by email {email}: {str(e)}", exc_info=True) + return None + + @staticmethod + async def find_by_phone(phone: str): + logger.info(f"Searching for user by phone: {phone}") + try: + user = await BookMyServiceUserModel.collection.find_one({"phone": phone}) + if user: + logger.info(f"User found by phone: {phone}") + else: + logger.info(f"No user found with phone: {phone}") + return user + except Exception as e: + logger.error(f"Error finding user by phone {phone}: {str(e)}", exc_info=True) + return None + + @staticmethod + async def find_by_mobile(mobile: str): + """Legacy method for backward compatibility - redirects to find_by_phone""" + logger.info(f"Legacy find_by_mobile called, redirecting to find_by_phone for: {mobile}") + return await BookMyServiceUserModel.find_by_phone(mobile) + + @staticmethod + async def find_by_identifier(identifier: str): + logger.info(f"Finding user by identifier: {identifier}") + + try: + # Validate and determine identifier type + identifier_type = validate_identifier(identifier) + logger.info(f"Identifier type determined: {identifier_type}") + + if identifier_type == "email": + logger.info(f"Searching by email for identifier: {identifier}") + user = await BookMyServiceUserModel.find_by_email(identifier) + elif identifier_type == "phone": + logger.info(f"Searching by phone for identifier: {identifier}") + user = await BookMyServiceUserModel.find_by_phone(identifier) + else: + logger.error(f"Invalid identifier type: {identifier_type}") + raise HTTPException(status_code=400, detail="Invalid identifier format") + + if not user: + logger.warning(f"User not found with identifier: {identifier}") + raise HTTPException(status_code=404, detail="User not found with this email or phone") + + logger.info(f"User found successfully for identifier: {identifier}") + logger.info(f"User data keys: {list(user.keys()) if user else 'None'}") + return user + + except ValueError as ve: + logger.error(f"Validation error for identifier {identifier}: {str(ve)}") + raise HTTPException(status_code=400, detail=str(ve)) + except HTTPException as e: + logger.error(f"HTTP error finding user by identifier {identifier}: {e.status_code} - {e.detail}") + raise e + except Exception as e: + logger.error(f"Unexpected error finding user by identifier {identifier}: {str(e)}", exc_info=True) + raise HTTPException(status_code=500, detail="Failed to find user") + + @staticmethod + async def exists_by_email_or_phone(email: str = None, phone: str = None) -> bool: + """Check if user exists by email or phone""" + query_conditions = [] + + if email: + query_conditions.append({"email": email}) + if phone: + query_conditions.append({"phone": phone}) + + if not query_conditions: + return False + + query = {"$or": query_conditions} if len(query_conditions) > 1 else query_conditions[0] + + result = await BookMyServiceUserModel.collection.find_one(query) + return result is not None + + @staticmethod + async def create(user_data: UserRegisterRequest): + user_dict = user_data.dict() + result = await BookMyServiceUserModel.collection.insert_one(user_dict) + return result.inserted_id + + @staticmethod + async def update_by_identifier(identifier: str, update_fields: dict): + try: + identifier_type = validate_identifier(identifier) + + if identifier_type == "email": + query = {"email": identifier} + elif identifier_type == "phone": + query = {"phone": identifier} + else: + raise HTTPException(status_code=400, detail="Invalid identifier format") + + result = await BookMyServiceUserModel.collection.update_one(query, {"$set": update_fields}) + if result.matched_count == 0: + raise HTTPException(status_code=404, detail="User not found") + return result.modified_count > 0 + + except ValueError as ve: + logger.error(f"Validation error for identifier {identifier}: {str(ve)}") + raise HTTPException(status_code=400, detail=str(ve)) + + @staticmethod + async def update_profile(customer_id: str, update_fields: dict): + """Update user profile by customer_id""" + try: + from datetime import datetime + + # Add updated_at timestamp + update_fields["updated_at"] = datetime.utcnow() + + result = await BookMyServiceUserModel.collection.update_one( + {"customer_id": customer_id}, + {"$set": update_fields} + ) + + if result.matched_count == 0: + raise HTTPException(status_code=404, detail="User not found") + + return result.modified_count > 0 + + except Exception as e: + logger.error(f"Error updating profile for user {customer_id}: {str(e)}") + raise HTTPException(status_code=500, detail="Failed to update profile") + + @staticmethod + async def find_by_id(customer_id: str): + """Find user by customer_id""" + try: + user = await BookMyServiceUserModel.collection.find_one({"customer_id": customer_id}) + return user + except Exception as e: + logger.error(f"Error finding user by ID {customer_id}: {str(e)}") + return None \ No newline at end of file diff --git a/app/models/wallet_model.py b/app/models/wallet_model.py new file mode 100644 index 0000000000000000000000000000000000000000..fe0148329b8278d8ad4c0d26281600b91ff7224b --- /dev/null +++ b/app/models/wallet_model.py @@ -0,0 +1,158 @@ +from datetime import datetime +from typing import Optional, List, Dict, Any +from bson import ObjectId +import logging + +from app.core.nosql_client import db + +logger = logging.getLogger(__name__) + +class WalletModel: + """Model for managing user wallet operations""" + + wallet_collection = db["user_wallets"] + transaction_collection = db["wallet_transactions"] + + @staticmethod + async def get_wallet_balance(customer_id: str) -> float: + """Get current wallet balance for a user""" + try: + wallet = await WalletModel.wallet_collection.find_one({"customer_id": customer_id}) + if wallet: + return wallet.get("balance", 0.0) + else: + # Create wallet if doesn't exist + await WalletModel.create_wallet(customer_id) + return 0.0 + except Exception as e: + logger.error(f"Error getting wallet balance for user {customer_id}: {str(e)}") + return 0.0 + + @staticmethod + async def create_wallet(customer_id: str, initial_balance: float = 0.0) -> bool: + """Create a new wallet for a user""" + try: + wallet_doc = { + "customer_id": customer_id, + "balance": initial_balance, + "created_at": datetime.utcnow(), + "updated_at": datetime.utcnow() + } + + result = await WalletModel.wallet_collection.insert_one(wallet_doc) + logger.info(f"Created wallet for user {customer_id} with balance {initial_balance}") + return result.inserted_id is not None + except Exception as e: + logger.error(f"Error creating wallet for user {customer_id}: {str(e)}") + return False + + @staticmethod + async def update_balance(customer_id: str, amount: float, transaction_type: str, + description: str = "", reference_id: str = None) -> bool: + """Update wallet balance and create transaction record""" + try: + # Get current balance + current_balance = await WalletModel.get_wallet_balance(customer_id) + + # Calculate new balance + if transaction_type in ["credit", "refund", "cashback"]: + new_balance = current_balance + amount + elif transaction_type in ["debit", "payment", "withdrawal"]: + if current_balance < amount: + logger.warning(f"Insufficient balance for user {customer_id}. Current: {current_balance}, Required: {amount}") + return False + new_balance = current_balance - amount + else: + logger.error(f"Invalid transaction type: {transaction_type}") + return False + + # Update wallet balance + update_result = await WalletModel.wallet_collection.update_one( + {"customer_id": customer_id}, + { + "$set": { + "balance": new_balance, + "updated_at": datetime.utcnow() + } + }, + upsert=True + ) + + # Create transaction record + transaction_doc = { + "customer_id": customer_id, + "amount": amount, + "transaction_type": transaction_type, + "description": description, + "reference_id": reference_id, + "balance_before": current_balance, + "balance_after": new_balance, + "timestamp": datetime.utcnow(), + "status": "completed" + } + + await WalletModel.transaction_collection.insert_one(transaction_doc) + + logger.info(f"Updated wallet for user {customer_id}: {transaction_type} of {amount}, new balance: {new_balance}") + return True + + except Exception as e: + logger.error(f"Error updating wallet balance for user {customer_id}: {str(e)}") + return False + + @staticmethod + async def get_transaction_history(customer_id: str, page: int = 1, per_page: int = 20) -> Dict[str, Any]: + """Get paginated transaction history for a user""" + try: + skip = (page - 1) * per_page + + # Get transactions with pagination + cursor = WalletModel.transaction_collection.find( + {"customer_id": customer_id} + ).sort("timestamp", -1).skip(skip).limit(per_page) + + transactions = [] + async for transaction in cursor: + # Convert ObjectId to string for JSON serialization + transaction["_id"] = str(transaction["_id"]) + transactions.append(transaction) + + # Get total count + total_count = await WalletModel.transaction_collection.count_documents({"customer_id": customer_id}) + + return { + "transactions": transactions, + "total_count": total_count, + "page": page, + "per_page": per_page, + "total_pages": (total_count + per_page - 1) // per_page + } + + except Exception as e: + logger.error(f"Error getting transaction history for user {customer_id}: {str(e)}") + return { + "transactions": [], + "total_count": 0, + "page": page, + "per_page": per_page, + "total_pages": 0 + } + + @staticmethod + async def get_wallet_summary(customer_id: str) -> Dict[str, Any]: + """Get wallet summary including balance and recent transactions""" + try: + balance = await WalletModel.get_wallet_balance(customer_id) + recent_transactions = await WalletModel.get_transaction_history(customer_id, page=1, per_page=5) + + return { + "balance": balance, + "recent_transactions": recent_transactions["transactions"] + } + + except Exception as e: + logger.error(f"Error getting wallet summary for user {customer_id}: {str(e)}") + return { + "balance": 0.0, + "recent_transactions": [] + } \ No newline at end of file diff --git a/app/routers/__init__.py b/app/routers/__init__.py index 4b901e2427b549cc28ecfd8ad21840ed12cdb304..d21fc87fd08f4cdb9aad99f2c433c65d15da531f 100644 --- a/app/routers/__init__.py +++ b/app/routers/__init__.py @@ -1,2 +1,11 @@ -__all__ = ["user"] - +__all__ = [ + "user_router", + "profile_router", + "account_router", + "wallet_router", + "address_router", + "pet_router", + "guest_router", + "favorite_router", + "review_router" +] diff --git a/app/routers/account_router.py b/app/routers/account_router.py new file mode 100644 index 0000000000000000000000000000000000000000..5d028693fd8cb4c752e0aee049dc4a3d188ba127 --- /dev/null +++ b/app/routers/account_router.py @@ -0,0 +1,218 @@ +from fastapi import APIRouter, Depends, HTTPException, Request, Query +from fastapi.security import HTTPBearer +from typing import List, Optional +from datetime import datetime, timedelta +import logging + +from app.schemas.user_schema import ( + LinkSocialAccountRequest, UnlinkSocialAccountRequest, + SocialAccountSummary, LoginHistoryResponse, SecuritySettingsResponse, + TokenResponse +) +from app.services.account_service import AccountService +from app.utils.jwt import decode_token + +# Configure logging +logger = logging.getLogger(__name__) + +router = APIRouter() +security = HTTPBearer() + +def get_current_user(token: str = Depends(security)): + """Extract user ID from JWT token""" + try: + payload = decode_token(token.credentials) + customer_id = payload.get("sub") + if not customer_id: + raise HTTPException(status_code=401, detail="Invalid token") + return customer_id + except Exception as e: + logger.error(f"Token validation error: {str(e)}") + raise HTTPException(status_code=401, detail="Invalid or expired token") + +def get_client_ip(request: Request) -> str: + """Extract client IP from request""" + forwarded_for = request.headers.get("X-Forwarded-For") + if forwarded_for: + return forwarded_for.split(",")[0].strip() + + real_ip = request.headers.get("X-Real-IP") + if real_ip: + return real_ip + + return request.client.host if request.client else "unknown" + +@router.get("/social-accounts", response_model=SocialAccountSummary) +async def get_social_accounts(customer_id: str = Depends(get_current_user)): + """Get all linked social accounts for the current user""" + try: + account_service = AccountService() + summary = await account_service.get_social_account_summary(customer_id) + return summary + except Exception as e: + logger.error(f"Error fetching social accounts for user {customer_id}: {str(e)}") + raise HTTPException(status_code=500, detail="Failed to fetch social accounts") + +@router.post("/link-social-account", response_model=dict) +async def link_social_account( + request: LinkSocialAccountRequest, + req: Request, + customer_id: str = Depends(get_current_user) +): + """Link a new social account to the current user""" + try: + client_ip = get_client_ip(req) + account_service = AccountService() + + result = await account_service.link_social_account( + customer_id=customer_id, + provider=request.provider, + token=request.token, + client_ip=client_ip + ) + + return {"message": f"Successfully linked {request.provider} account", "result": result} + except ValueError as e: + logger.warning(f"Invalid link request for user {customer_id}: {str(e)}") + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + logger.error(f"Error linking social account for user {customer_id}: {str(e)}") + raise HTTPException(status_code=500, detail="Failed to link social account") + +@router.delete("/unlink-social-account", response_model=dict) +async def unlink_social_account( + request: UnlinkSocialAccountRequest, + customer_id: str = Depends(get_current_user) +): + """Unlink a social account from the current user""" + try: + account_service = AccountService() + + result = await account_service.unlink_social_account( + customer_id=customer_id, + provider=request.provider + ) + + return {"message": f"Successfully unlinked {request.provider} account", "result": result} + except ValueError as e: + logger.warning(f"Invalid unlink request for user {customer_id}: {str(e)}") + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + logger.error(f"Error unlinking social account for user {customer_id}: {str(e)}") + raise HTTPException(status_code=500, detail="Failed to unlink social account") + +@router.get("/login-history", response_model=LoginHistoryResponse) +async def get_login_history( + page: int = Query(1, ge=1, description="Page number"), + per_page: int = Query(10, ge=1, le=50, description="Items per page"), + days: int = Query(30, ge=1, le=365, description="Number of days to look back"), + customer_id: str = Depends(get_current_user) +): + """Get login history for the current user""" + try: + account_service = AccountService() + + history = await account_service.get_login_history( + customer_id=customer_id, + page=page, + per_page=per_page, + days=days + ) + + return history + except Exception as e: + logger.error(f"Error fetching login history for user {customer_id}: {str(e)}") + raise HTTPException(status_code=500, detail="Failed to fetch login history") + +@router.get("/security-settings", response_model=SecuritySettingsResponse) +async def get_security_settings(customer_id: str = Depends(get_current_user)): + """Get security settings and status for the current user""" + try: + account_service = AccountService() + + settings = await account_service.get_security_settings(customer_id) + + return settings + except Exception as e: + logger.error(f"Error fetching security settings for user {customer_id}: {str(e)}") + raise HTTPException(status_code=500, detail="Failed to fetch security settings") + +@router.post("/merge-accounts", response_model=dict) +async def merge_social_accounts( + target_customer_id: str, + req: Request, + customer_id: str = Depends(get_current_user) +): + """Merge social accounts from another user (admin function or user-initiated)""" + try: + # For security, only allow users to merge their own accounts or implement admin check + if customer_id != target_customer_id: + # In a real implementation, you'd check if the current user is an admin + # or if they have proper authorization to merge accounts + raise HTTPException(status_code=403, detail="Insufficient permissions") + + client_ip = get_client_ip(req) + account_service = AccountService() + + result = await account_service.merge_social_accounts( + primary_customer_id=customer_id, + secondary_customer_id=target_customer_id, + client_ip=client_ip + ) + + return {"message": "Successfully merged social accounts", "result": result} + except ValueError as e: + logger.warning(f"Invalid merge request for user {customer_id}: {str(e)}") + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + logger.error(f"Error merging social accounts for user {customer_id}: {str(e)}") + raise HTTPException(status_code=500, detail="Failed to merge social accounts") + +@router.delete("/revoke-all-sessions", response_model=dict) +async def revoke_all_sessions( + req: Request, + customer_id: str = Depends(get_current_user) +): + """Revoke all active sessions for security purposes""" + try: + client_ip = get_client_ip(req) + account_service = AccountService() + + result = await account_service.revoke_all_sessions(customer_id, client_ip) + + return {"message": "All sessions have been revoked", "result": result} + except Exception as e: + logger.error(f"Error revoking sessions for user {customer_id}: {str(e)}") + raise HTTPException(status_code=500, detail="Failed to revoke sessions") + +@router.get("/trusted-devices", response_model=dict) +async def get_trusted_devices(customer_id: str = Depends(get_current_user)): + """Get list of trusted devices for the current user""" + try: + account_service = AccountService() + + devices = await account_service.get_trusted_devices(customer_id) + + return {"devices": devices} + except Exception as e: + logger.error(f"Error fetching trusted devices for user {customer_id}: {str(e)}") + raise HTTPException(status_code=500, detail="Failed to fetch trusted devices") + +@router.delete("/trusted-devices/{device_id}", response_model=dict) +async def remove_trusted_device( + device_id: str, + customer_id: str = Depends(get_current_user) +): + """Remove a trusted device""" + try: + account_service = AccountService() + + result = await account_service.remove_trusted_device(customer_id, device_id) + + return {"message": "Trusted device removed successfully", "result": result} + except ValueError as e: + logger.warning(f"Invalid device removal request for user {customer_id}: {str(e)}") + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + logger.error(f"Error removing trusted device for user {customer_id}: {str(e)}") + raise HTTPException(status_code=500, detail="Failed to remove trusted device") \ No newline at end of file diff --git a/app/routers/address_router.py b/app/routers/address_router.py new file mode 100644 index 0000000000000000000000000000000000000000..43358662f09b1389c35b99b06d2d957861dfcfd6 --- /dev/null +++ b/app/routers/address_router.py @@ -0,0 +1,327 @@ +from fastapi import APIRouter, Depends, HTTPException, status +from typing import List +import logging + +from app.utils.jwt import get_current_customer_id +from app.models.address_model import AddressModel +from app.schemas.address_schema import ( + AddressCreateRequest, AddressUpdateRequest, AddressResponse, + AddressListResponse, SetDefaultAddressRequest, AddressOperationResponse +) + +logger = logging.getLogger(__name__) + +router = APIRouter() + +@router.get("/", response_model=AddressListResponse) +async def get_user_addresses(current_customer_id: str = Depends(get_current_customer_id)): + """ + Get all delivery addresses for the current user. + + This endpoint is JWT protected and requires a valid Bearer token. + """ + try: + logger.info(f"Get addresses request for user: {current_customer_id}") + + addresses = await AddressModel.get_user_addresses(current_customer_id) + + address_responses = [] + for addr in addresses: + address_responses.append(AddressResponse( + address_id=addr["address_id"], # Use the new address_id field + address_type=addr["address_type"], + address_line_1=addr["address_line_1"], + address_line_2=addr.get("address_line_2", ""), + city=addr["city"], + state=addr["state"], + postal_code=addr["postal_code"], + country=addr.get("country", "India"), + landmark=addr.get("landmark", ""), + is_default=addr.get("is_default", False), + created_at=addr.get("created_at"), + updated_at=addr.get("updated_at") + )) + + if not address_responses: + return AddressListResponse( + success=False, + message="Failed to retrieve address" + ) + + else: + return AddressListResponse( + success=True, + message="Addresses retrieved successfully", + addresses=address_responses, + total_count=len(address_responses) + ) + + + except Exception as e: + logger.error(f"Error getting addresses for user {current_customer_id}: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to retrieve addresses" + ) + +@router.post("/", response_model=AddressOperationResponse) +async def create_address( + address_data: AddressCreateRequest, + current_customer_id: str = Depends(get_current_customer_id) +): + """ + Create a new delivery address for the current user. + """ + try: + logger.info(f"Create address request for user: {current_customer_id}") + + # Check if user already has 5 addresses (limit) + existing_addresses = await AddressModel.get_user_addresses(current_customer_id) + if len(existing_addresses) >= 5: + raise HTTPException( + status_code=400, + detail="Maximum of 5 addresses allowed per user" + ) + + # If this is the first address, make it default + is_default = len(existing_addresses) == 0 or address_data.is_default + + address_id = await AddressModel.create_address(current_customer_id, address_data.dict()) + + + if address_id: + # Get the created address + created_address = await AddressModel.get_address_by_id(current_customer_id,address_id) + + address_response = AddressResponse( + address_id=created_address["address_id"], # Use the new address_id field + address_type=created_address["address_type"], + address_line_1=created_address["address_line_1"], + address_line_2=created_address.get("address_line_2", ""), + city=created_address["city"], + state=created_address["state"], + postal_code=created_address["postal_code"], + country=created_address.get("country", "India"), + is_default=created_address.get("is_default", False), + landmark=created_address.get("landmark", ""), + created_at=created_address.get("created_at"), + updated_at=created_address.get("updated_at") + ) + return AddressOperationResponse( + success=True, + message="Address created successfully", + address=address_response + ) + else: + return AddressOperationResponse( + success=False, + message="Failed to create address" + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error creating address for user {current_customer_id}: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to create address" + ) + +@router.put("/{address_id}", response_model=AddressOperationResponse) +async def update_address( + address_id: str, + address_data: AddressUpdateRequest, + current_customer_id: str = Depends(get_current_customer_id) +): + """ + Update an existing delivery address. + """ + try: + logger.info(f"Update address request for user: {current_customer_id}, address: {address_id}") + + # Check if address exists and belongs to user + existing_address = await AddressModel.get_address_by_id(current_customer_id,address_id) + if not existing_address: + raise HTTPException(status_code=404, detail="Address not found") + + if existing_address["customer_id"] != current_customer_id: + raise HTTPException(status_code=403, detail="Access denied") + + # Prepare update fields + update_fields = {} + + if address_data.address_type is not None: + update_fields["address_type"] = address_data.address_type + if address_data.address_line_1 is not None: + update_fields["address_line_1"] = address_data.address_line_1 + if address_data.address_line_2 is not None: + update_fields["address_line_2"] = address_data.address_line_2 + if address_data.city is not None: + update_fields["city"] = address_data.city + if address_data.state is not None: + update_fields["state"] = address_data.state + if address_data.is_default is not None: + update_fields['is_default']=address_data.is_default + if address_data.landmark is not None: + update_fields["landmark"] = address_data.landmark + if address_data.postal_code is not None: + update_fields["postal_code"] = address_data.postal_code + if address_data.country is not None: + update_fields["country"] = address_data.country + if not update_fields: + raise HTTPException(status_code=400, detail="No fields to update") + + success = await AddressModel.update_address(current_customer_id,address_id, update_fields) + + if success: + # Get updated address + updated_address = await AddressModel.get_address_by_id(current_customer_id,address_id) + + address_response = AddressResponse( + address_id=updated_address["address_id"], # Use the new address_id field + address_type=updated_address["address_type"], + address_line_1=updated_address["address_line_1"], + address_line_2=updated_address.get("address_line_2", ""), + city=updated_address["city"], + state=updated_address["state"], + postal_code=updated_address["postal_code"], + country=updated_address.get("country", "India"), + landmark=updated_address.get("landmark", ""), + is_default=updated_address.get("is_default", False), + created_at=updated_address.get("created_at"), + updated_at=updated_address.get("updated_at") + ) + + return AddressOperationResponse( + success=True, + message="Address updated successfully", + address=address_response + ) + else: + return AddressOperationResponse( + success=False, + message="Failed to update address" + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error updating address {address_id} for user {current_customer_id}: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to update address" + ) + +@router.delete("/{address_id}", response_model=AddressOperationResponse) +async def delete_address( + address_id: str, + current_customer_id: str = Depends(get_current_customer_id) +): + """ + Delete a delivery address. + """ + try: + logger.info(f"Delete address request for user: {current_customer_id}, address: {address_id}") + + # Check if address exists and belongs to user + existing_address = await AddressModel.get_address_by_id(current_customer_id,address_id) + if not existing_address: + raise HTTPException(status_code=404, detail="Address not found") + + if existing_address["customer_id"] != current_customer_id: + raise HTTPException(status_code=403, detail="Access denied") + + # Check if this is the default address + if existing_address.get("is_default", False): + # Get other addresses to potentially set a new default + user_addresses = await AddressModel.get_user_addresses(current_customer_id) + # Compare by new domain id field 'address_id' + other_addresses = [addr for addr in user_addresses if addr.get("address_id") != address_id] + + if other_addresses: + # Set the first other address as default + await AddressModel.set_default_address(current_customer_id, other_addresses[0]["address_id"]) + + success = await AddressModel.delete_address(current_customer_id,address_id) + + if success: + return AddressOperationResponse( + success=True, + message="Address deleted successfully" + ) + else: + return AddressOperationResponse( + success=False, + message="Failed to delete address" + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error deleting address {address_id} for user {current_customer_id}: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to delete address" + ) + +@router.post("/set-default", response_model=AddressOperationResponse) +async def set_default_address( + request: SetDefaultAddressRequest, + current_customer_id: str = Depends(get_current_customer_id) +): + """ + Set an address as the default delivery address. + """ + try: + logger.info(f"Set default address request for user: {current_customer_id}, address: {request.address_id}") + + # Check if address exists and belongs to user + existing_address = await AddressModel.get_address_by_id(current_customer_id,request.address_id) + if not existing_address: + raise HTTPException(status_code=404, detail="Address not found") + + if existing_address["customer_id"] != current_customer_id: + raise HTTPException(status_code=403, detail="Access denied") + + success = await AddressModel.set_default_address(current_customer_id, request.address_id) + + if success: + # Get updated address + updated_address = await AddressModel.get_address_by_id(current_customer_id,request.address_id) + + address_response = AddressResponse( + address_id=request.address_id, + address_type=updated_address["address_type"], + contact_name=updated_address.get("contact_name", ""), + contact_phone=updated_address.get("contact_phone", ""), + address_line_1=updated_address["address_line_1"], + address_line_2=updated_address.get("address_line_2", ""), + city=updated_address["city"], + state=updated_address["state"], + postal_code=updated_address["postal_code"], + country=updated_address.get("country", "India"), + landmark=updated_address.get("landmark", ""), + is_default=updated_address.get("is_default", False), + created_at=updated_address.get("created_at"), + updated_at=updated_address.get("updated_at") + ) + + return AddressOperationResponse( + success=True, + message="Default address set successfully", + address=address_response + ) + else: + return AddressOperationResponse( + success=False, + message="Failed to set default address" + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error setting default address for user {current_customer_id}: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to set default address" + ) \ No newline at end of file diff --git a/app/routers/favorite_router.py b/app/routers/favorite_router.py new file mode 100644 index 0000000000000000000000000000000000000000..8160afe6cdf5e1f737f289c3471a9010cb95578b --- /dev/null +++ b/app/routers/favorite_router.py @@ -0,0 +1,205 @@ +from fastapi import APIRouter, Depends, HTTPException, Query, status +from typing import Optional +from app.schemas.favorite_schema import ( + FavoriteCreateRequest, + FavoriteUpdateRequest, + FavoriteResponse, + FavoritesListResponse, + FavoriteStatusResponse, + FavoriteSuccessResponse, + FavoriteDataResponse +) +from app.services.favorite_service import FavoriteService +from app.services.user_service import UserService +from app.utils.jwt import get_current_user +import logging + +logger = logging.getLogger("favorite_router") + +router = APIRouter( + prefix="/favorites", + tags=["favorites"], + responses={404: {"description": "Not found"}}, +) + +@router.post( + "", + response_model=FavoriteSuccessResponse, + status_code=status.HTTP_201_CREATED, + summary="Add merchant to favorites", + description="Add a merchant to the user's list of favorite merchants" +) +async def add_favorite( + favorite_data: FavoriteCreateRequest, + current_user: dict = Depends(get_current_user) +): + """ + Add a merchant to user's favorites. + + - **merchant_id**: ID of the merchant to favorite + - **merchant_category**: Category of the merchant (salon, spa, pet_spa, etc.) + - **merchant_name**: Name of the merchant for quick display + - **notes**: Optional note about this favorite + """ + try: + return await FavoriteService.add_favorite( + customer_id=current_user["sub"], + favorite_data=favorite_data + ) + except HTTPException as e: + raise e + except Exception as e: + logger.error(f"Error in add_favorite endpoint: {str(e)}", exc_info=True) + raise HTTPException(status_code=500, detail="Internal server error") + +@router.delete( + "/{merchant_id}", + response_model=FavoriteSuccessResponse, + status_code=status.HTTP_200_OK, + summary="Remove merchant from favorites", + description="Remove a merchant from the user's list of favorite merchants" +) +async def remove_favorite( + merchant_id: str, + current_user: dict = Depends(get_current_user) +): + """ + Remove a merchant from user's favorites. + + - **merchant_id**: ID of the merchant to remove from favorites + """ + try: + + customer_id=current_user.get("sub") + + # Check if favorite merchant already exists and belongs to user + existing_favorite = await FavoriteService.get_favorite_details(customer_id, merchant_id) + + if not existing_favorite.success: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="favorite not found" + ) + + return await FavoriteService.remove_favorite( + customer_id, + merchant_id + ) + except HTTPException as e: + raise e + except Exception as e: + logger.error(f"Error in remove_favorite endpoint: {str(e)}", exc_info=True) + raise HTTPException(status_code=500, detail="Internal server error") + +@router.get( + "", + response_model=FavoritesListResponse, + status_code=status.HTTP_200_OK, + summary="List favorite merchants", + description="Get user's list of favorite merchants, optionally filtered by category" +) +async def list_favorites( + limit: int = Query(50, ge=1, le=100, description="Maximum number of items to return"), + current_user: dict = Depends(get_current_user) +): + """ + Get user's favorite merchants. + + - **category**: Optional filter by merchant category + - **skip**: Number of items to skip (for pagination) + - **limit**: Maximum number of items to return (max 100) + """ + try: + return await FavoriteService.get_favorites( + customer_id=current_user["sub"], + limit=limit + ) + except HTTPException as e: + raise e + except Exception as e: + logger.error(f"Error in list_favorites endpoint: {str(e)}", exc_info=True) + raise HTTPException(status_code=500, detail="Internal server error") + +@router.get( + "/{merchant_id}/status", + response_model=FavoriteStatusResponse, + status_code=status.HTTP_200_OK, + summary="Check favorite status", + description="Check if a merchant is in user's favorites" +) +async def check_favorite_status( + merchant_id: str, + current_user: dict = Depends(get_current_user) +): + """ + Check if a merchant is in user's favorites. + + - **merchant_id**: ID of the merchant to check + """ + try: + return await FavoriteService.check_favorite_status( + customer_id=current_user["sub"], + merchant_id=merchant_id + ) + except HTTPException as e: + raise e + except Exception as e: + logger.error(f"Error in check_favorite_status endpoint: {str(e)}", exc_info=True) + raise HTTPException(status_code=500, detail="Internal server error") + +@router.get( + "/{merchant_id}", + response_model=FavoriteDataResponse, + status_code=status.HTTP_200_OK, + summary="Get favorite details", + description="Get detailed information about a specific favorite merchant" +) +async def get_favorite_details( + merchant_id: str, + current_user: dict = Depends(get_current_user) +): + """ + Get detailed information about a specific favorite merchant. + + - **merchant_id**: ID of the merchant + """ + try: + return await FavoriteService.get_favorite_details( + customer_id=current_user["sub"], + merchant_id=merchant_id + ) + except HTTPException as e: + raise e + except Exception as e: + logger.error(f"Error in get_favorite_details endpoint: {str(e)}", exc_info=True) + raise HTTPException(status_code=500, detail="Internal server error") + +@router.patch( + "/{merchant_id}/notes", + response_model=FavoriteSuccessResponse, + status_code=status.HTTP_200_OK, + summary="Update favorite notes", + description="Update the notes for a favorite merchant" +) +async def update_favorite_notes( + merchant_id: str, + notes_data: FavoriteUpdateRequest, + current_user: dict = Depends(get_current_user) +): + """ + Update the notes for a favorite merchant. + + - **merchant_id**: ID of the merchant + - **notes**: Updated note about this favorite + """ + try: + return await FavoriteService.update_favorite_notes( + customer_id=current_user["sub"], + merchant_id=merchant_id, + notes_data=notes_data + ) + except HTTPException as e: + raise e + except Exception as e: + logger.error(f"Error in update_favorite_notes endpoint: {str(e)}", exc_info=True) + raise HTTPException(status_code=500, detail="Internal server error") \ No newline at end of file diff --git a/app/routers/guest_router.py b/app/routers/guest_router.py new file mode 100644 index 0000000000000000000000000000000000000000..6c12e015784ef38b25bf60a23bcfee94d689d200 --- /dev/null +++ b/app/routers/guest_router.py @@ -0,0 +1,309 @@ +from fastapi import APIRouter, HTTPException, Depends, status +from fastapi.security import HTTPBearer +from app.models.guest_model import GuestModel +from app.schemas.guest_schema import ( + GuestCreateRequest, + GuestUpdateRequest, + GuestResponse, + GuestListResponse, + GuestDeleteResponse, + SetDefaultGuestRequest +) +from app.utils.jwt import verify_token +from typing import Dict, Any +import logging + +logger = logging.getLogger(__name__) + +router = APIRouter() +security = HTTPBearer() + +async def get_current_user(token: str = Depends(security)) -> Dict[str, Any]: + """ + Dependency to get current authenticated user from JWT token + """ + try: + payload = verify_token(token.credentials) + if not payload: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or expired token" + ) + return payload + except Exception as e: + logger.error(f"Token verification failed: {str(e)}") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or expired token" + ) + +@router.get("/guests", response_model=GuestListResponse) +async def get_user_guests( + current_user: Dict[str, Any] = Depends(get_current_user) +): + """ + Get all guests for a specific user. + + - **customer_id**: ID of the user + - Returns list of guests with total count + """ + try: + # Verify user can only access their own guests + customer_id=current_user.get("sub") + + guests_data = await GuestModel.get_user_guests(customer_id) + + guests = [GuestResponse(**guest) for guest in guests_data] + + return GuestListResponse( + guests=guests, + total_count=len(guests) + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error getting guests for user {customer_id}: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to retrieve guests" + ) + +@router.post("/guests", response_model=GuestResponse, status_code=status.HTTP_201_CREATED) +async def create_guest( + guest_data: GuestCreateRequest, + current_user: Dict[str, Any] = Depends(get_current_user) +): + """ + Create a new guest profile for a user. + + - **customer_id**: ID of the user creating the guest profile + - **guest_data**: Guest information including name, contact details, etc. + - Returns the created guest profile + """ + try: + # Verify user can only create guests for themselves + customer_id=current_user.get("sub") + + # Create guest in database + + guest_id=await GuestModel.create_guest(customer_id,guest_data.dict()) + + if not guest_id: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to create guest profile" + ) + + # Retrieve and return the created guest + created_guest = await GuestModel.get_guest_by_id(customer_id, guest_id) + if not created_guest: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Guest created but failed to retrieve" + ) + + return GuestResponse(**created_guest) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error creating guest for user {customer_id}: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to create guest profile" + ) + +@router.get("/guests/default", response_model=GuestResponse) +async def get_default_guest( + current_user: Dict[str, Any] = Depends(get_current_user) +): + """Get the default guest for the current user""" + try: + customer_id = current_user.get("sub") + default_guest = await GuestModel.get_default_guest(customer_id) + if not default_guest: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Default guest not set" + ) + return GuestResponse(**default_guest) + except HTTPException: + raise + except Exception as e: + logger.error(f"Error getting default guest for user {customer_id}: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to retrieve default guest" + ) + +@router.post("/guests/set-default", response_model=GuestResponse) +async def set_default_guest( + req: SetDefaultGuestRequest, + current_user: Dict[str, Any] = Depends(get_current_user) +): + """Set a guest as default for the current user""" + try: + customer_id = current_user.get("sub") + # Verify guest exists and belongs to user + existing_guest = await GuestModel.get_guest_by_id(customer_id, req.guest_id) + if not existing_guest: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Guest not found" + ) + if existing_guest.get("customer_id") != customer_id: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Access denied. This guest doesn't belong to you." + ) + success = await GuestModel.set_default_guest(customer_id, req.guest_id) + if not success: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to set default guest" + ) + updated_guest = await GuestModel.get_guest_by_id(customer_id, req.guest_id) + if not updated_guest: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Default set but failed to retrieve guest" + ) + return GuestResponse(**updated_guest) + except HTTPException: + raise + except Exception as e: + logger.error(f"Error setting default guest {req.guest_id} for user {customer_id}: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to set default guest" + ) +@router.put("/guests/{guest_id}", response_model=GuestResponse) +async def update_guest( + guest_id: str, + guest_data: GuestUpdateRequest, + current_user: Dict[str, Any] = Depends(get_current_user) +): + """ + Update an existing guest profile. + + - **customer_id**: ID of the user who owns the guest profile + - **guest_id**: ID of the guest to update + - **guest_data**: Updated guest information + - Returns the updated guest profile + """ + try: + # Verify user can only update their own guests + customer_id=current_user.get("sub") + + # Check if guest exists and belongs to user + existing_guest = await GuestModel.get_guest_by_id(customer_id, guest_id) + if not existing_guest: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Guest not found" + ) + + if existing_guest.get("customer_id") != customer_id: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Access denied. This guest doesn't belong to you." + ) + + # Prepare update fields (only include non-None values) + update_fields = {} + for field, value in guest_data.dict(exclude_unset=True).items(): + if value is not None: + if hasattr(value, 'value'): # Handle enum values + update_fields[field] = value.value + else: + update_fields[field] = value + + if not update_fields: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="No valid fields provided for update" + ) + + # Update guest in database + success = await GuestModel.update_guest(customer_id, guest_id, update_fields) + if not success: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to update guest profile" + ) + + # Retrieve and return updated guest + updated_guest = await GuestModel.get_guest_by_id(customer_id, guest_id) + if not updated_guest: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Guest updated but failed to retrieve" + ) + + return GuestResponse(**updated_guest) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error updating guest {guest_id} for user {customer_id}: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to update guest profile" + ) + +@router.delete("/guests/{guest_id}", response_model=GuestDeleteResponse) +async def delete_guest( + guest_id: str, + current_user: Dict[str, Any] = Depends(get_current_user) +): + """ + Delete a guest profile. + + - **customer_id**: ID of the user who owns the guest profile + - **guest_id**: ID of the guest to delete + - Returns confirmation of deletion + """ + try: + # Verify user can only delete their own guests + customer_id=current_user.get("sub") + + # Check if guest exists and belongs to user + existing_guest = await GuestModel.get_guest_by_id(customer_id, guest_id) + if not existing_guest: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Guest not found" + ) + + if existing_guest.get("customer_id") != customer_id: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Access denied. This guest doesn't belong to you." + ) + + # Delete guest from database + success = await GuestModel.delete_guest(customer_id, guest_id) + if not success: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to delete guest profile" + ) + + guest_name = existing_guest.get('first_name', 'Guest') + if existing_guest.get('last_name'): + guest_name += f" {existing_guest.get('last_name')}" + + return GuestDeleteResponse( + message=f"Guest '{guest_name}' has been successfully deleted", + guest_id=guest_id + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error deleting guest {guest_id} for user {customer_id}: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to delete guest profile" + ) \ No newline at end of file diff --git a/app/routers/pet_router.py b/app/routers/pet_router.py new file mode 100644 index 0000000000000000000000000000000000000000..753143444053b7d379f2dcf89e0435022aa1c3aa --- /dev/null +++ b/app/routers/pet_router.py @@ -0,0 +1,305 @@ +from fastapi import APIRouter, HTTPException, Depends, status +from fastapi.security import HTTPBearer +from app.models.pet_model import PetModel +from app.schemas.pet_schema import ( + PetCreateRequest, + PetUpdateRequest, + PetResponse, + PetListResponse, + PetDeleteResponse, + SetDefaultPetRequest +) +from app.utils.jwt import verify_token +from typing import Dict, Any +import logging + +logger = logging.getLogger(__name__) + +router = APIRouter() +security = HTTPBearer() + +async def get_current_user(token: str = Depends(security)) -> Dict[str, Any]: + """ + Dependency to get current authenticated user from JWT token + """ + try: + payload = verify_token(token.credentials) + if not payload: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or expired token" + ) + return payload + except Exception as e: + logger.error(f"Token verification failed: {str(e)}") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or expired token" + ) + +@router.get("/pets", response_model=PetListResponse) +async def get_user_pets( + current_user: Dict[str, Any] = Depends(get_current_user) +): + """ + Get all pets for a specific user. + + - **customer_id**: ID of the pet owner + - Returns list of pets with total count + """ + try: + # Verify user can only access their own pets + customer_id=current_user.get("sub") + + pets_data = await PetModel.get_user_pets(customer_id) + + pets = [PetResponse(**pet) for pet in pets_data] + + return PetListResponse( + pets=pets, + total_count=len(pets) + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error getting pets for user {customer_id}: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to retrieve pets" + ) + +@router.post("/pets", response_model=PetResponse, status_code=status.HTTP_201_CREATED) +async def create_pet( + pet_data: PetCreateRequest, + current_user: Dict[str, Any] = Depends(get_current_user) +): + """ + Create a new pet profile for a user. + + - **customer_id**: ID of the pet owner + - **pet_data**: Pet information including name, species, breed, etc. + - Returns the created pet profile + """ + try: + # Verify user can only create pets for themselves + customer_id=current_user.get("sub") + + + pet_id = await PetModel.create_pet(customer_id,pet_data.dict()) + # Create pet in database + + if not pet_id: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to create pet profile" + ) + + # Retrieve and return the created pet + created_pet = await PetModel.get_pet_by_id(customer_id, pet_id) + if not created_pet: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Pet created but failed to retrieve" + ) + + return PetResponse(**created_pet) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error creating pet for user {customer_id}: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to create pet profile" + ) + +@router.get("/pets/default", response_model=PetResponse) +async def get_default_pet( + current_user: Dict[str, Any] = Depends(get_current_user) +): + """Get the default pet for the current user""" + try: + customer_id = current_user.get("sub") + default_pet = await PetModel.get_default_pet(customer_id) + if not default_pet: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Default pet not set" + ) + return PetResponse(**default_pet) + except HTTPException: + raise + except Exception as e: + logger.error(f"Error getting default pet for user {customer_id}: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to retrieve default pet" + ) + +@router.post("/pets/set-default", response_model=PetResponse) +async def set_default_pet( + req: SetDefaultPetRequest, + current_user: Dict[str, Any] = Depends(get_current_user) +): + """Set a pet as default for the current user""" + try: + customer_id = current_user.get("sub") + # Verify pet exists and belongs to user + existing_pet = await PetModel.get_pet_by_id(customer_id, req.pet_id) + if not existing_pet: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Pet not found" + ) + if existing_pet.get("customer_id") != customer_id: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Access denied. This pet doesn't belong to you." + ) + success = await PetModel.set_default_pet(customer_id, req.pet_id) + if not success: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to set default pet" + ) + updated_pet = await PetModel.get_pet_by_id(customer_id, req.pet_id) + if not updated_pet: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Default set but failed to retrieve pet" + ) + return PetResponse(**updated_pet) + except HTTPException: + raise + except Exception as e: + logger.error(f"Error setting default pet {req.pet_id} for user {customer_id}: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to set default pet" + ) +@router.put("/pets/{pet_id}", response_model=PetResponse) +async def update_pet( + pet_id: str, + pet_data: PetUpdateRequest, + current_user: Dict[str, Any] = Depends(get_current_user) +): + """ + Update an existing pet profile. + + - **customer_id**: ID of the pet owner + - **pet_id**: ID of the pet to update + - **pet_data**: Updated pet information + - Returns the updated pet profile + """ + try: + # Verify user can only update their own pets + customer_id=current_user.get("sub") + + # Check if pet exists and belongs to user + existing_pet = await PetModel.get_pet_by_id(customer_id, pet_id) + if not existing_pet: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Pet not found" + ) + + if existing_pet.get("customer_id") != customer_id: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Access denied. This pet doesn't belong to you." + ) + + # Prepare update fields (only include non-None values) + update_fields = {} + for field, value in pet_data.dict(exclude_unset=True).items(): + if value is not None: + if hasattr(value, 'value'): # Handle enum values + update_fields[field] = value.value + else: + update_fields[field] = value + + if not update_fields: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="No valid fields provided for update" + ) + + # Update pet in database + success = await PetModel.update_pet(customer_id, pet_id, update_fields) + if not success: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to update pet profile" + ) + + # Retrieve and return updated pet + updated_pet = await PetModel.get_pet_by_id(customer_id, pet_id) + if not updated_pet: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Pet updated but failed to retrieve" + ) + + return PetResponse(**updated_pet) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error updating pet {pet_id} for user {customer_id}: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to update pet profile" + ) + +@router.delete("/pets/{pet_id}", response_model=PetDeleteResponse) +async def delete_pet( + pet_id: str, + current_user: Dict[str, Any] = Depends(get_current_user) +): + """ + Delete a pet profile. + + - **customer_id**: ID of the pet owner + - **pet_id**: ID of the pet to delete + - Returns confirmation of deletion + """ + try: + # Verify user can only delete their own pets + customer_id=current_user.get("sub") + + # Check if pet exists and belongs to user + existing_pet = await PetModel.get_pet_by_id(customer_id, pet_id) + if not existing_pet: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Pet not found" + ) + + if existing_pet.get("customer_id") != customer_id: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Access denied. This pet doesn't belong to you." + ) + + # Delete pet from database + success = await PetModel.delete_pet(customer_id, pet_id) + if not success: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to delete pet profile" + ) + + return PetDeleteResponse( + message=f"Pet '{existing_pet.get('pet_name')}' has been successfully deleted", + pet_id=pet_id + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error deleting pet {pet_id} for user {customer_id}: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to delete pet profile" + ) \ No newline at end of file diff --git a/app/routers/profile_router.py b/app/routers/profile_router.py new file mode 100644 index 0000000000000000000000000000000000000000..c3494e6fd8f97774fc8fff746bb22f2d4a70f81a --- /dev/null +++ b/app/routers/profile_router.py @@ -0,0 +1,200 @@ + +## bookmyservice-ums/app/routers/profile_router.py + +from fastapi import APIRouter, Depends, HTTPException, status +from typing import Dict, Any +import logging + +from app.utils.jwt import get_current_customer_id +from app.services.profile_service import profile_service +from app.services.wallet_service import WalletService +from app.models.user_model import BookMyServiceUserModel +from app.models.address_model import AddressModel +from app.schemas.profile_schema import ( + ProfileUpdateRequest, ProfileResponse, ProfileOperationResponse, + PersonalDetailsResponse, WalletDisplayResponse, ProfileDashboardResponse +) + +logger = logging.getLogger(__name__) + +router = APIRouter() + +@router.get("/me", response_model=Dict[str, Any]) +async def get_profile(current_customer_id: str = Depends(get_current_customer_id)): + """ + Get current user's profile from customers collection. + + This endpoint is JWT protected and requires a valid Bearer token. + + Args: + current_customer_id (str): User ID extracted from JWT token + + Returns: + Dict[str, Any]: Customer profile data + + Raises: + HTTPException: 401 if token is invalid, 404 if profile not found + """ + try: + logger.info(f"Profile request for user: {current_customer_id}") + + # Fetch customer profile using the service + profile_data = await profile_service.get_customer_profile(current_customer_id) + + return { + "success": True, + "message": "Profile retrieved successfully", + "data": profile_data + } + + except HTTPException: + # Re-raise HTTP exceptions from service + raise + except Exception as e: + logger.error(f"Unexpected error in get_profile: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Internal server error" + ) + +@router.get("/dashboard", response_model=ProfileDashboardResponse) +async def get_profile_dashboard(current_customer_id: str = Depends(get_current_customer_id)): + """ + Get complete profile dashboard with personal details, wallet, and address info. + + This endpoint matches the screenshot requirements showing: + - Personal details (name, email, phone, DOB) + - Wallet balance + - Address management info + """ + try: + logger.info(f"Dashboard request for user: {current_customer_id}") + + # Get user profile + user = await BookMyServiceUserModel.find_by_id(current_customer_id) + if not user: + raise HTTPException(status_code=404, detail="User not found") + + # Parse name into first and last name + name_parts = user.get("name", "").split(" ", 1) + first_name = name_parts[0] if name_parts else "" + last_name = name_parts[1] if len(name_parts) > 1 else "" + + # Get wallet balance + wallet_balance = await WalletService.get_wallet_balance(current_customer_id) + + # Get address count and default address status + addresses = await AddressModel.get_user_addresses(current_customer_id) + address_count = len(addresses) + has_default_address = any(addr.get("is_default", False) for addr in addresses) + + # Build response + personal_details = PersonalDetailsResponse( + first_name=first_name, + last_name=last_name, + email=user.get("email", ""), + phone=user.get("phone", ""), + date_of_birth=user.get("date_of_birth") + ) + + wallet_display = WalletDisplayResponse( + balance=wallet_balance.balance, + formatted_balance=wallet_balance.formatted_balance, + currency=wallet_balance.currency + ) + + return ProfileDashboardResponse( + personal_details=personal_details, + wallet=wallet_display, + address_count=address_count, + has_default_address=has_default_address + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error getting profile dashboard for user {current_customer_id}: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Internal server error" + ) + +@router.put("/update", response_model=ProfileOperationResponse) +async def update_profile( + profile_data: ProfileUpdateRequest, + current_customer_id: str = Depends(get_current_customer_id) +): + """ + Update user profile information including personal details and DOB. + """ + try: + logger.info(f"Profile update request for user: {current_customer_id}") + + # Prepare update fields + update_fields = {} + + if profile_data.name is not None: + update_fields["name"] = profile_data.name + + if profile_data.email is not None: + # Check if email is already used by another user + existing_user = await BookMyServiceUserModel.find_by_email(str(profile_data.email)) + if existing_user and existing_user.get("customer_id") != current_customer_id: + raise HTTPException(status_code=409, detail="Email already in use by another account") + update_fields["email"] = str(profile_data.email) + + if profile_data.phone is not None: + # Check if phone is already used by another user + existing_user = await BookMyServiceUserModel.find_by_phone(profile_data.phone) + if existing_user and existing_user.get("customer_id") != current_customer_id: + raise HTTPException(status_code=409, detail="Phone number already in use by another account") + update_fields["phone"] = profile_data.phone + + if profile_data.date_of_birth is not None: + update_fields["date_of_birth"] = profile_data.date_of_birth + + if profile_data.profile_picture is not None: + update_fields["profile_picture"] = profile_data.profile_picture + + if not update_fields: + raise HTTPException(status_code=400, detail="No fields to update") + + # Update profile + success = await BookMyServiceUserModel.update_profile(current_customer_id, update_fields) + + if success: + # Get updated profile + updated_user = await BookMyServiceUserModel.find_by_id(current_customer_id) + + profile_response = ProfileResponse( + customer_id=updated_user["customer_id"], + name=updated_user["name"], + email=updated_user.get("email"), + phone=updated_user.get("phone"), + date_of_birth=updated_user.get("date_of_birth"), + profile_picture=updated_user.get("profile_picture"), + auth_method=updated_user.get("auth_mode", "unknown"), + created_at=updated_user.get("created_at"), + updated_at=updated_user.get("updated_at") + ) + + return ProfileOperationResponse( + success=True, + message="Profile updated successfully", + profile=profile_response + ) + else: + return ProfileOperationResponse( + success=False, + message="Failed to update profile" + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error updating profile for user {current_customer_id}: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Internal server error" + ) + diff --git a/app/routers/review_router.py b/app/routers/review_router.py new file mode 100644 index 0000000000000000000000000000000000000000..72ee26ccf552ad9058bde1f9ad96cdce5a568185 --- /dev/null +++ b/app/routers/review_router.py @@ -0,0 +1,48 @@ +from fastapi import APIRouter, Depends, HTTPException, status +import logging +from app.utils.jwt import get_current_user +from typing import Dict, Any + +from app.schemas.review_schema import ReviewCreateRequest,ReviewResponse +from app.models.review_model import ReviewModel + +logger = logging.getLogger(__name__) + +router = APIRouter() + +@router.post("/add", response_model=ReviewResponse, + status_code=status.HTTP_201_CREATED) +async def create_review( + review_data: ReviewCreateRequest, + current_user: Dict[str, Any] = Depends(get_current_user) +): + try: + + + review_reponse=await ReviewModel.create_review(review_data.dict()) + + if not review_reponse: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="failed to add review details" + ) + return ReviewResponse( + merchant_id=review_reponse["merchant_id"], + location_id=review_reponse["location_id"], + user_name=review_reponse["user_name"], + rating=review_reponse["rating"], + review_text=review_reponse["review_text"], + review_date=review_reponse["review_date"], + verified_purchase=review_reponse["verified_purchase"] + + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error while adding review: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to add review" + ) + diff --git a/app/routers/user_router.py b/app/routers/user_router.py new file mode 100644 index 0000000000000000000000000000000000000000..6f4e2ba5842031be877f7e8140f95a940525d9e5 --- /dev/null +++ b/app/routers/user_router.py @@ -0,0 +1,556 @@ +from fastapi import APIRouter, HTTPException, Depends, Security, status +from fastapi.security import APIKeyHeader + +from app.schemas.user_schema import ( + OTPRequest, + OTPRequestWithLogin, + OTPVerifyRequest, + UserRegisterRequest, + UserLoginRequest, + OAuthLoginRequest, + TokenResponse, + OTPSendResponse, +) +from app.services.user_service import UserService +from app.utils.jwt import create_temp_token, decode_token, create_refresh_token, get_current_customer_id +from app.utils.social_utils import verify_google_token, verify_google_access_token, verify_apple_token, verify_facebook_token +from app.utils.common_utils import validate_identifier +from app.models.social_security_model import SocialSecurityModel +from app.models.user_model import BookMyServiceUserModel +from fastapi import Request +import logging + +logger = logging.getLogger("user_router") + +router = APIRouter() + +# 🔐 Declare API key header scheme (Swagger shows a single token input box) +api_key_scheme = APIKeyHeader(name="Authorization", auto_error=False) + +# 🔍 Bearer token parser +# More flexible bearer token parser +def get_bearer_token(api_key: str = Security(api_key_scheme)) -> str: + try: + # Check if Authorization header is missing + if not api_key: + logger.warning("Missing Authorization header") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Missing Authorization header" + ) + + # If "Bearer " prefix is included, strip it + logger.info(f"Received Authorization header: {api_key}") + if api_key.lower().startswith("bearer "): + return api_key[7:] # Remove "Bearer " prefix + + # Else, assume it's already a raw JWT + return api_key + + except HTTPException: + # Re-raise HTTP exceptions + raise + except Exception as e: + logger.error(f"Error processing Authorization header: {str(e)}") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid Authorization header format" + ) + + +# 📧📱 Send OTP using single login input (preferred endpoint) +@router.post("/send-otp-login", response_model=OTPSendResponse) +async def send_otp_login_handler(payload: OTPRequestWithLogin): + logger.info(f"OTP login request started - login_input: {payload.login_input}") + + try: + # Validate identifier format + try: + identifier_type = validate_identifier(payload.login_input) + logger.info(f"Login input type: {identifier_type}") + except ValueError as ve: + logger.error(f"Invalid login input format: {str(ve)}") + raise HTTPException(status_code=400, detail=str(ve)) + + # Check if user already exists + user_exists = False + if identifier_type == "email": + user_exists = await BookMyServiceUserModel.exists_by_email_or_phone(email=payload.login_input) + elif identifier_type == "phone": + user_exists = await BookMyServiceUserModel.exists_by_email_or_phone(phone=payload.login_input) + + logger.info(f"User existence check result: {user_exists}") + + # Send OTP via service + logger.info(f"Calling UserService.send_otp with identifier: {payload.login_input}") + await UserService.send_otp(payload.login_input) + logger.info(f"OTP sent successfully to: {payload.login_input}") + + # Create temporary token + logger.info("Creating temporary token for OTP verification") + temp_token = create_temp_token({ + "sub": payload.login_input, + "type": "otp_verification" + }, expires_minutes=10) + + logger.info(f"Temporary token created for: {payload.login_input}") + logger.info(f"Temp token (first 20 chars): {temp_token[:20]}...") + + return { + "message": "OTP sent", + "temp_token": temp_token, + "user_exists": user_exists + } + + except HTTPException as e: + logger.error(f"OTP login request failed - HTTP {e.status_code}: {e.detail}") + raise e + except Exception as e: + logger.error(f"Unexpected error during OTP login request: {str(e)}", exc_info=True) + raise HTTPException(status_code=500, detail="Internal server error during OTP login request") + +# 🔐 OTP Login using temporary token +@router.post("/otp-login", response_model=TokenResponse) +async def otp_login_handler( + payload: OTPVerifyRequest, + request: Request, + temp_token: str = Depends(get_bearer_token) +): + logger.info(f"OTP login attempt started - login_input: {payload.login_input}, remember_me: {payload.remember_me}") + logger.info(f"Received temp_token: {temp_token[:20]}..." if temp_token else "No temp_token received") + + # Get client IP + client_ip = request.client.host if request.client else None + + try: + # Decode and validate temporary token + logger.info("Attempting to decode temporary token") + decoded = decode_token(temp_token) + logger.info(f"Decoded token payload: {decoded}") + + if not decoded: + logger.warning("Failed to decode temporary token - token is invalid or expired") + raise HTTPException(status_code=401, detail="Invalid or expired OTP session token") + + # Validate token subject matches login input + token_sub = decoded.get("sub") + token_type = decoded.get("type") + logger.info(f"Token subject: {token_sub}, Token type: {token_type}") + + if token_sub != payload.login_input: + logger.warning(f"Token subject mismatch - token_sub: {token_sub}, login_input: {payload.login_input}") + raise HTTPException(status_code=401, detail="Invalid or expired OTP session token") + + if token_type != "otp_verification": + logger.warning(f"Invalid token type - expected: otp_verification, got: {token_type}") + raise HTTPException(status_code=401, detail="Invalid or expired OTP session token") + + logger.info(f"Temporary token validation successful for: {payload.login_input}") + + # Call user service for OTP verification and login + logger.info(f"Calling UserService.otp_login_handler with identifier: {payload.login_input}, otp: {payload.otp}") + result = await UserService.otp_login_handler( + payload.login_input, + payload.otp, + client_ip=client_ip, + remember_me=payload.remember_me, + device_info=payload.device_info + ) + + logger.info(f"OTP login successful for: {payload.login_input}") + return result + + except HTTPException as e: + logger.error(f"OTP login failed for {payload.login_input} - HTTP {e.status_code}: {e.detail}") + raise e + except Exception as e: + logger.error(f"Unexpected error during OTP login for {payload.login_input}: {str(e)}", exc_info=True) + raise HTTPException(status_code=500, detail="Internal server error during OTP login") + +# 🌐 OAuth Login for Google / Apple +@router.post("/oauth-login", response_model=TokenResponse) +async def oauth_login_handler(payload: OAuthLoginRequest, request: Request): + from app.core.config import settings + + # Get client IP + client_ip = request.client.host if request.client else None + + # Check if IP is locked for this provider + if await SocialSecurityModel.is_oauth_ip_locked(client_ip, payload.provider): + await SocialSecurityModel.log_oauth_attempt(client_ip, payload.provider, False) + raise HTTPException( + status_code=429, + detail=f"Too many failed attempts. IP temporarily locked for {payload.provider} OAuth." + ) + + # Check rate limiting + if not await SocialSecurityModel.check_oauth_rate_limit(client_ip, payload.provider): + await SocialSecurityModel.log_oauth_attempt(client_ip, payload.provider, False) + raise HTTPException( + status_code=429, + detail=f"Rate limit exceeded for {payload.provider} OAuth. Please try again later." + ) + + # Validate token format (allow Google access tokens too) + if not await SocialSecurityModel.validate_oauth_token_format(payload.token, payload.provider): + await SocialSecurityModel.track_oauth_failed_attempt(client_ip, payload.provider) + await SocialSecurityModel.log_oauth_attempt(client_ip, payload.provider, False) + raise HTTPException(status_code=400, detail="Invalid token format") + + # Increment rate limit counter + await SocialSecurityModel.increment_oauth_rate_limit(client_ip, payload.provider) + + try: + if payload.provider == "google": + # Accept both ID tokens (JWT) and access tokens + token = payload.token + is_jwt = token.count('.') == 2 + if is_jwt: + # ID token verification requires the configured client id + if not settings.GOOGLE_CLIENT_ID: + raise HTTPException(status_code=500, detail="Google OAuth not configured") + user_info = await verify_google_token(token, settings.GOOGLE_CLIENT_ID) + else: + # Access token verification via UserInfo does not require client id + if token.lower().startswith("bearer "): + token = token[7:] + user_info = await verify_google_access_token(token) + provider_customer_id = user_info.get('sub', user_info.get('id')) + customer_id = f"google_{provider_customer_id}" + + elif payload.provider == "apple": + if not settings.APPLE_AUDIENCE: + raise HTTPException(status_code=500, detail="Apple OAuth not configured") + user_info = await verify_apple_token(payload.token, settings.APPLE_AUDIENCE) + provider_customer_id = user_info.get('sub', user_info.get('id')) + customer_id = f"apple_{provider_customer_id}" + + elif payload.provider == "facebook": + if not settings.FACEBOOK_APP_ID or not settings.FACEBOOK_APP_SECRET: + raise HTTPException(status_code=500, detail="Facebook OAuth not configured") + user_info = await verify_facebook_token(payload.token, settings.FACEBOOK_APP_ID, settings.FACEBOOK_APP_SECRET) + provider_customer_id = user_info.get('id') + customer_id = f"facebook_{provider_customer_id}" + + else: + raise HTTPException(status_code=400, detail="Unsupported OAuth provider") + + # Clear failed attempts on successful verification + await SocialSecurityModel.clear_oauth_failed_attempts(client_ip, payload.provider) + + # Log successful attempt + await SocialSecurityModel.log_oauth_attempt(client_ip, payload.provider, True, customer_id) + + # Resolve existing UUID via social linkage if available + user_exists = False + resolved_customer_uuid = None + try: + from app.models.social_account_model import SocialAccountModel + linked_account = await SocialAccountModel.find_by_provider_and_customer_id(payload.provider, provider_customer_id) + if linked_account and linked_account.get("customer_id"): + resolved_customer_uuid = linked_account["customer_id"] + user_exists = True + else: + # Backward compatibility: some records may still have provider-prefixed customer_id + existing_user = await BookMyServiceUserModel.collection.find_one({"customer_id": customer_id}) + if existing_user: + resolved_customer_uuid = existing_user["customer_id"] + user_exists = True + except Exception: + # Do not block login on existence check errors; default to False + user_exists = False + + temp_token = create_temp_token({ + "sub": resolved_customer_uuid or customer_id, + "type": "oauth_session", + "verified": True, + "provider": payload.provider, + "user_info": user_info + }) + + # Log temporary OAuth session token (truncated) + logger.info(f"OAuth temp token generated (first 25 chars): {temp_token[:25]}...") + + # Populate response with available customer details for frontend convenience + return { + "access_token": temp_token, + "token_type": "bearer", + "expires_in": settings.JWT_TEMP_TOKEN_EXPIRE_MINUTES * 60, + "refresh_token": None, + "customer_id": resolved_customer_uuid if user_exists else None, + "name": user_info.get("name"), + "email": user_info.get("email"), + "profile_picture": user_info.get("picture"), + "auth_method": "oauth", + "provider": payload.provider, + "user_exists": user_exists, + "security_info": { + "verified": user_info.get("email_verified"), + "provider_user_id": provider_customer_id + } + } + + except HTTPException: + # Re-raise HTTP exceptions (configuration errors, etc.) + raise + except Exception as e: + # Track failed attempt for token verification failures + await SocialSecurityModel.track_oauth_failed_attempt(client_ip, payload.provider) + await SocialSecurityModel.log_oauth_attempt(client_ip, payload.provider, False) + logger.error(f"OAuth verification failed for {payload.provider}: {str(e)}", exc_info=True) + raise HTTPException(status_code=401, detail="OAuth token verification failed") + +# 👤 Final user registration after OTP or OAuth +@router.post("/register", response_model=TokenResponse) +async def register_user( + payload: UserRegisterRequest, + temp_token: str = Depends(get_bearer_token) +): + logger.info(f"Received registration request with payload: {payload}") + + decoded = decode_token(temp_token) + if not decoded or decoded.get("type") not in ["otp_verification", "oauth_session"]: + raise HTTPException(status_code=401, detail="Invalid or expired registration token") + + logger.info(f"Registering user with payload: {payload}") + + result = await UserService.register(payload, decoded) + # Log tokens returned on register (truncated) + if result and isinstance(result, dict): + at = result.get("access_token") + rt = result.get("refresh_token") + if at: + logger.info(f"Register access token (first 25 chars): {at[:25]}...") + if rt: + logger.info(f"Register refresh token (first 25 chars): {rt[:25]}...") + return result + +# 🔄 Refresh access token using refresh token with rotation +@router.post("/refresh-token", response_model=TokenResponse) +async def refresh_token_handler( + request: Request, + refresh_token: str = Depends(get_bearer_token) +): + from app.models.refresh_token_model import RefreshTokenModel + from app.utils.jwt import create_access_token + + logger.info("Refresh token request received") + + # Get client IP + client_ip = request.client.host if request.client else None + + try: + # Decode and validate refresh token + decoded = decode_token(refresh_token) + logger.info(f"Decoded refresh token payload: {decoded}") + + if not decoded: + logger.warning("Failed to decode refresh token - token is invalid or expired") + raise HTTPException(status_code=401, detail="Invalid or expired refresh token") + + # Validate token type + token_type = decoded.get("type") + if token_type != "refresh": + logger.warning(f"Invalid token type for refresh - expected: refresh, got: {token_type}") + raise HTTPException(status_code=401, detail="Invalid refresh token") + + # Extract token information + customer_id = decoded.get("sub") + token_id = decoded.get("jti") + family_id = decoded.get("family_id") + remember_me = decoded.get("remember_me", False) + + if not customer_id or not token_id: + logger.warning("Refresh token missing required claims") + raise HTTPException(status_code=401, detail="Invalid refresh token") + + # Validate token hasn't been revoked or reused + if not await RefreshTokenModel.is_token_valid(token_id): + logger.error(f"Token {token_id} is invalid - possible security breach") + raise HTTPException( + status_code=401, + detail="Invalid refresh token. Please login again." + ) + + # Mark current token as used + await RefreshTokenModel.mark_token_as_used(token_id) + + # Increment rotation count + if family_id: + await RefreshTokenModel.increment_rotation_count(family_id) + + logger.info(f"Refresh token validation successful for user: {customer_id}") + + # Create new access token + access_token = create_access_token({ + "sub": customer_id + }) + + # Create new refresh token (rotation) + new_refresh_token, new_token_id, new_expires_at = create_refresh_token( + {"sub": customer_id}, + remember_me=remember_me, + family_id=family_id + ) + + # Store new refresh token metadata + token_metadata = await RefreshTokenModel.get_token_metadata(token_id) + await RefreshTokenModel.store_refresh_token( + token_id=new_token_id, + customer_id=customer_id, + family_id=family_id, + expires_at=new_expires_at, + remember_me=remember_me, + device_info=token_metadata.get("device_info") if token_metadata else None, + ip_address=client_ip + ) + + logger.info(f"New tokens generated for user: {customer_id} (rotation)") + + return { + "access_token": access_token, + "token_type": "bearer", + "expires_in": settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60, + "refresh_token": new_refresh_token, + "customer_id": customer_id + } + + except HTTPException as e: + logger.error(f"Refresh token failed - HTTP {e.status_code}: {e.detail}") + raise e + except Exception as e: + logger.error(f"Unexpected error during refresh token: {str(e)}", exc_info=True) + raise HTTPException(status_code=500, detail="Internal server error during token refresh") + + +# 🚪 Logout - Revoke refresh token +@router.post("/logout") +async def logout_handler( + refresh_token: str = Depends(get_bearer_token) +): + from app.models.refresh_token_model import RefreshTokenModel + + logger.info("Logout request received") + + try: + # Decode refresh token to get token ID + decoded = decode_token(refresh_token) + + if decoded and decoded.get("type") == "refresh": + token_id = decoded.get("jti") + customer_id = decoded.get("sub") + + if token_id: + # Revoke the refresh token + await RefreshTokenModel.revoke_token(token_id) + logger.info(f"Revoked refresh token {token_id} for user {customer_id}") + + return { + "message": "Logged out successfully", + "success": True + } + + return { + "message": "Invalid token", + "success": False + } + + except Exception as e: + logger.error(f"Error during logout: {str(e)}", exc_info=True) + raise HTTPException(status_code=500, detail="Internal server error during logout") + +# 🚪 Logout from all devices - Revoke all refresh tokens +@router.post("/logout-all") +async def logout_all_handler( + customer_id: str = Depends(get_current_customer_id) +): + from app.models.refresh_token_model import RefreshTokenModel + + logger.info(f"Logout all devices request for user: {customer_id}") + + try: + # Revoke all refresh tokens for the user + revoked_count = await RefreshTokenModel.revoke_all_user_tokens(customer_id) + + logger.info(f"Revoked {revoked_count} tokens for user {customer_id}") + + return { + "message": f"Logged out from {revoked_count} device(s) successfully", + "success": True, + "revoked_count": revoked_count + } + + except Exception as e: + logger.error(f"Error during logout all: {str(e)}", exc_info=True) + raise HTTPException(status_code=500, detail="Internal server error during logout") + +# 📱 Get active sessions +@router.get("/sessions") +async def get_active_sessions_handler( + customer_id: str = Depends(get_current_customer_id) +): + from app.models.refresh_token_model import RefreshTokenModel + + logger.info(f"Get active sessions request for user: {customer_id}") + + try: + sessions = await RefreshTokenModel.get_active_sessions(customer_id) + + # Format session data for response + formatted_sessions = [] + for session in sessions: + formatted_sessions.append({ + "token_id": session.get("token_id"), + "device_info": session.get("device_info"), + "ip_address": session.get("ip_address"), + "created_at": session.get("created_at"), + "expires_at": session.get("expires_at"), + "remember_me": session.get("remember_me", False), + "last_used": session.get("used_at") + }) + + return { + "sessions": formatted_sessions, + "total": len(formatted_sessions) + } + + except Exception as e: + logger.error(f"Error getting active sessions: {str(e)}", exc_info=True) + raise HTTPException(status_code=500, detail="Internal server error getting sessions") + +# 🗑️ Revoke specific session +@router.delete("/sessions/{token_id}") +async def revoke_session_handler( + token_id: str, + customer_id: str = Depends(get_current_customer_id) +): + from app.models.refresh_token_model import RefreshTokenModel + + logger.info(f"Revoke session request for token: {token_id} by user: {customer_id}") + + try: + # Verify the token belongs to the user + token_metadata = await RefreshTokenModel.get_token_metadata(token_id) + + if not token_metadata: + raise HTTPException(status_code=404, detail="Session not found") + + if token_metadata.get("customer_id") != customer_id: + raise HTTPException(status_code=403, detail="Unauthorized to revoke this session") + + # Revoke the token + success = await RefreshTokenModel.revoke_token(token_id) + + if success: + return { + "message": "Session revoked successfully", + "success": True + } + else: + raise HTTPException(status_code=404, detail="Session not found or already revoked") + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error revoking session: {str(e)}", exc_info=True) + raise HTTPException(status_code=500, detail="Internal server error revoking session") diff --git a/app/routers/wallet_router.py b/app/routers/wallet_router.py new file mode 100644 index 0000000000000000000000000000000000000000..1b08007eb7a2ed9f41476e2935895e713d6f2170 --- /dev/null +++ b/app/routers/wallet_router.py @@ -0,0 +1,223 @@ +from fastapi import APIRouter, Depends, HTTPException, status, Query +from typing import List, Optional +import logging + +from app.utils.jwt import get_current_customer_id +from app.services.wallet_service import WalletService +from app.schemas.wallet_schema import ( + WalletBalanceResponse, TransactionHistoryResponse, WalletSummaryResponse, + AddMoneyRequest, WithdrawMoneyRequest, TransactionRequest, TransactionResponse +) + +logger = logging.getLogger(__name__) + +router = APIRouter() + +@router.get("/balance", response_model=WalletBalanceResponse) +async def get_wallet_balance(current_customer_id: str = Depends(get_current_customer_id)): + """ + Get current user's wallet balance. + + This endpoint is JWT protected and requires a valid Bearer token. + """ + try: + logger.info(f"Wallet balance request for user: {current_customer_id}") + + balance_info = await WalletService.get_wallet_balance(current_customer_id) + return balance_info + + except Exception as e: + logger.error(f"Error getting wallet balance for user {current_customer_id}: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to retrieve wallet balance" + ) + +@router.get("/summary", response_model=WalletSummaryResponse) +async def get_wallet_summary(current_customer_id: str = Depends(get_current_customer_id)): + """ + Get wallet summary including balance and recent transaction stats. + """ + try: + logger.info(f"Wallet summary request for user: {current_customer_id}") + + summary = await WalletService.get_wallet_summary(current_customer_id) + return summary + + except Exception as e: + logger.error(f"Error getting wallet summary for user {current_customer_id}: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to retrieve wallet summary" + ) + +@router.get("/transactions", response_model=TransactionHistoryResponse) +async def get_transaction_history( + current_customer_id: str = Depends(get_current_customer_id), + page: int = Query(1, ge=1, description="Page number"), + limit: int = Query(10, ge=1, le=100, description="Number of transactions per page"), + transaction_type: Optional[str] = Query(None, description="Filter by transaction type") +): + """ + Get paginated transaction history for the current user. + + Query parameters: + - page: Page number (default: 1) + - limit: Number of transactions per page (default: 10, max: 100) + - transaction_type: Filter by type (credit, debit, refund, etc.) + """ + try: + logger.info(f"Transaction history request for user: {current_customer_id}, page: {page}, limit: {limit}") + + history = await WalletService.get_transaction_history( + current_customer_id, + page=page, + limit=limit, + transaction_type=transaction_type + ) + return history + + except Exception as e: + logger.error(f"Error getting transaction history for user {current_customer_id}: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to retrieve transaction history" + ) + +@router.post("/add-money", response_model=TransactionResponse) +async def add_money_to_wallet( + request: AddMoneyRequest, + current_customer_id: str = Depends(get_current_customer_id) +): + """ + Add money to user's wallet. + + This would typically integrate with a payment gateway. + For now, it simulates adding money to the wallet. + """ + try: + logger.info(f"Add money request for user: {current_customer_id}, amount: {request.amount}") + + if request.amount <= 0: + raise HTTPException(status_code=400, detail="Amount must be greater than zero") + + transaction = await WalletService.add_money( + current_customer_id, + request.amount, + request.payment_method, + request.reference_id, + request.description + ) + + return TransactionResponse( + success=True, + message="Money added successfully", + transaction=transaction, + new_balance=transaction.balance_after + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error adding money for user {current_customer_id}: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to add money to wallet" + ) + +@router.post("/withdraw", response_model=TransactionResponse) +async def withdraw_money( + request: WithdrawMoneyRequest, + current_customer_id: str = Depends(get_current_customer_id) +): + """ + Withdraw money from user's wallet. + + This would typically integrate with a payment gateway for bank transfers. + """ + try: + logger.info(f"Withdraw request for user: {current_customer_id}, amount: {request.amount}") + + if request.amount <= 0: + raise HTTPException(status_code=400, detail="Amount must be greater than zero") + + # Check if user has sufficient balance + balance_info = await WalletService.get_wallet_balance(current_customer_id) + if balance_info.balance < request.amount: + raise HTTPException(status_code=400, detail="Insufficient wallet balance") + + transaction = await WalletService.deduct_money( + current_customer_id, + request.amount, + "withdrawal", + request.bank_account_id, + request.description or "Wallet withdrawal" + ) + + return TransactionResponse( + success=True, + message="Withdrawal processed successfully", + transaction=transaction, + new_balance=transaction.balance_after + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error processing withdrawal for user {current_customer_id}: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to process withdrawal" + ) + +@router.post("/transaction", response_model=TransactionResponse) +async def create_transaction( + request: TransactionRequest, + current_customer_id: str = Depends(get_current_customer_id) +): + """ + Create a generic transaction (for internal use, service bookings, etc.). + """ + try: + logger.info(f"Transaction request for user: {current_customer_id}, type: {request.transaction_type}, amount: {request.amount}") + + if request.amount <= 0: + raise HTTPException(status_code=400, detail="Amount must be greater than zero") + + if request.transaction_type == "debit": + # Check if user has sufficient balance for debit transactions + balance_info = await WalletService.get_wallet_balance(current_customer_id) + if balance_info.balance < request.amount: + raise HTTPException(status_code=400, detail="Insufficient wallet balance") + + transaction = await WalletService.deduct_money( + current_customer_id, + request.amount, + request.category or "service", + request.reference_id, + request.description + ) + else: # credit or refund + transaction = await WalletService.add_money( + current_customer_id, + request.amount, + request.category or "refund", + request.reference_id, + request.description + ) + + return TransactionResponse( + success=True, + message=f"Transaction processed successfully", + transaction=transaction, + new_balance=transaction.balance_after + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error creating transaction for user {current_customer_id}: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to process transaction" + ) \ No newline at end of file diff --git a/app/schemas/__init__.py b/app/schemas/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/app/schemas/address_schema.py b/app/schemas/address_schema.py new file mode 100644 index 0000000000000000000000000000000000000000..9ca9eb215c4be0fc8bb6a843412f8b2f00bf7681 --- /dev/null +++ b/app/schemas/address_schema.py @@ -0,0 +1,74 @@ +from pydantic import BaseModel, Field, validator +from datetime import datetime +from typing import List, Optional, Literal + +class AddressCreateRequest(BaseModel): + """Request model for creating a new address""" + address_line_1: str = Field(..., min_length=5, max_length=200, description="Primary address line") + address_line_2: Optional[str] = Field("", max_length=200, description="Secondary address line") + city: str = Field(..., min_length=2, max_length=100, description="City name") + state: str = Field(..., min_length=2, max_length=100, description="State name") + postal_code: str = Field(..., min_length=5, max_length=10, description="Postal/ZIP code") + country: str = Field(default="India", max_length=100, description="Country name") + address_type: Literal["home", "work", "other"] = Field(default="home", description="Type of address") + is_default: bool = Field(default=False, description="Set as default address") + landmark: Optional[str] = Field("", max_length=200, description="Nearby landmark") + + @validator('postal_code') + def validate_postal_code(cls, v): + if not v.isdigit(): + raise ValueError('Postal code must contain only digits') + return v + +class AddressUpdateRequest(BaseModel): + """Request model for updating an existing address""" + address_line_1: Optional[str] = Field(None, min_length=5, max_length=200, description="Primary address line") + address_line_2: Optional[str] = Field(None, max_length=200, description="Secondary address line") + city: Optional[str] = Field(None, min_length=2, max_length=100, description="City name") + state: Optional[str] = Field(None, min_length=2, max_length=100, description="State name") + postal_code: Optional[str] = Field(None, min_length=5, max_length=10, description="Postal/ZIP code") + country: Optional[str] = Field(None, max_length=100, description="Country name") + address_type: Optional[Literal["home", "work", "other"]] = Field(None, description="Type of address") + is_default: Optional[bool] = Field(None, description="Set as default address") + landmark: Optional[str] = Field(None, max_length=200, description="Nearby landmark") + + @validator('postal_code') + def validate_postal_code(cls, v): + if v and not v.isdigit(): + raise ValueError('Postal code must contain only digits') + return v + +class AddressResponse(BaseModel): + """Response model for address data""" + address_id: str = Field(..., description="Unique address ID") + address_line_1: str = Field(..., description="Primary address line") + address_line_2: str = Field(..., description="Secondary address line") + city: str = Field(..., description="City name") + state: str = Field(..., description="State name") + postal_code: str = Field(..., description="Postal/ZIP code") + country: str = Field(..., description="Country name") + address_type: str = Field(..., description="Type of address") + is_default: bool = Field(..., description="Is default address") + landmark: str = Field(..., description="Nearby landmark") + created_at: datetime = Field(..., description="Address creation timestamp") + updated_at: datetime = Field(..., description="Address last update timestamp") + + class Config: + from_attributes = True + +class AddressListResponse(BaseModel): + """Response model for list of addresses""" + success: bool = Field(..., description="Operation success status") + message: str = Field(..., description="Response message") + addresses: Optional[List[AddressResponse]]= Field(None, description="List of user addresses") + total_count: Optional[int] = Field(0, description="Total number of addresses") + +class SetDefaultAddressRequest(BaseModel): + """Request model for setting default address""" + address_id: str = Field(..., description="Address ID to set as default") + +class AddressOperationResponse(BaseModel): + """Response model for address operations""" + success: bool = Field(..., description="Operation success status") + message: str = Field(..., description="Response message") + address: Optional[AddressResponse] = Field(None, description="Address ID if applicable") \ No newline at end of file diff --git a/app/schemas/favorite_schema.py b/app/schemas/favorite_schema.py new file mode 100644 index 0000000000000000000000000000000000000000..c149f274f84c4cbb68cac043216b562520181a91 --- /dev/null +++ b/app/schemas/favorite_schema.py @@ -0,0 +1,45 @@ +from pydantic import BaseModel, Field +from typing import Optional, List +from datetime import datetime + +class FavoriteCreateRequest(BaseModel): + """Request schema for creating a favorite""" + merchant_id: str = Field(..., description="ID of the merchant to favorite") + source: str = Field(..., description="Source of the favorite action") + merchant_category: str = Field(..., description="Category of the merchant (salon, spa, pet_spa, etc.)") + merchant_name: str = Field(..., description="Name of the merchant for quick display") + notes: Optional[str] = Field(None, description="Optional note about this favorite") +class FavoriteUpdateRequest(BaseModel): + """Request schema for updating favorite notes""" + notes: Optional[str] = Field(None, description="Updated note about this favorite") + +class FavoriteResponse(BaseModel): + """Response schema for a favorite merchant""" + merchant_id: str = Field(..., description="ID of the merchant") + source: str = Field(..., description="Source of the favorite action") + merchant_category: str = Field(..., description="Category of the merchant") + merchant_name: str = Field(..., description="Name of the merchant") + added_at: datetime = Field(..., description="When the merchant was favorited") + notes: Optional[str] = Field(None, description="Optional note about this favorite") + +class FavoritesListResponse(BaseModel): + """Response schema for listing favorites""" + favorites: List[FavoriteResponse] = Field(..., description="List of favorite merchants") + total_count: int = Field(..., description="Total number of favorites") + limit: int = Field(..., description="Maximum number of items returned") + +class FavoriteStatusResponse(BaseModel): + """Response schema for checking favorite status""" + is_favorite: bool = Field(..., description="Whether the merchant is in favorites") + merchant_id: Optional[str] = Field(None, description="ID of the merchant if exists") + +class FavoriteSuccessResponse(BaseModel): + """Response schema for successful favorite operations""" + success: bool = Field(..., description="Operation success status") + message: str = Field(..., description="Operation result message") + merchant_id: Optional[str] = Field(None, description="ID of the merchant") + +class FavoriteDataResponse(BaseModel): + success: bool = Field(..., description="Operation success status") + message: str = Field(..., description="Operation result message") + favorite_data: Optional[FavoriteResponse]=Field(None, description="List of favorite merchants") \ No newline at end of file diff --git a/app/schemas/guest_schema.py b/app/schemas/guest_schema.py new file mode 100644 index 0000000000000000000000000000000000000000..13f0ddf7dc0580a914ca37433cfa721a41d77b18 --- /dev/null +++ b/app/schemas/guest_schema.py @@ -0,0 +1,231 @@ +from pydantic import BaseModel, Field, validator, EmailStr +from typing import Optional, List +from datetime import datetime,date +from enum import Enum + +class GenderEnum(str, Enum): + MALE = "Male" + FEMALE = "Female" + OTHER = "Other" + +class RelationshipEnum(str, Enum): + FAMILY = "Family" + FRIEND = "Friend" + COLLEAGUE = "Colleague" + OTHER = "Other" + +class GuestCreateRequest(BaseModel): + """Schema for creating a new guest profile""" + first_name: str = Field(..., min_length=1, max_length=100, description="Guest's first name") + last_name: Optional[str] = Field(None, max_length=100, description="Guest's last name") + email: Optional[EmailStr] = Field(None, description="Guest's email address") + phone_number: Optional[str] = Field(None, max_length=20, description="Guest's phone number") + gender: Optional[GenderEnum] = Field(None, description="Guest's gender") + date_of_birth: Optional[date] = Field(None, description="Guest's date of birth for age calculation") + relationship: Optional[RelationshipEnum] = Field(None, description="Relationship to the user") + notes: Optional[str] = Field(None, max_length=500, description="Additional notes about the guest") + is_default: Optional[bool] = Field(None, description="Mark as default guest") + + @validator('email', pre=True) + def optional_email_empty_to_none(cls, v): + if v is None: + return None + if isinstance(v, str) and v.strip() == '': + return None + return v + + @validator('phone_number', pre=True) + def optional_phone_empty_to_none(cls, v): + if v is None: + return None + if isinstance(v, str) and v.strip() == '': + return None + return v + + @validator('date_of_birth', pre=True) + def coerce_date_of_birth(cls, v): + if v is None: + return None + if isinstance(v, str): + s = v.strip() + if s == '': + return None + try: + # Accept ISO date or datetime strings; convert to date + if 'T' in s or 'Z' in s or '+' in s: + return datetime.fromisoformat(s.replace('Z', '+00:00')).date() + return date.fromisoformat(s) + except Exception: + return v + if isinstance(v, datetime): + return v.date() + return v + + @validator('first_name') + def validate_first_name(cls, v): + if not v or not v.strip(): + raise ValueError('First name cannot be empty') + return v.strip() + + @validator('last_name') + def validate_last_name(cls, v): + if v is not None and v.strip() == '': + return None + return v.strip() if v else v + + @validator('phone_number') + def validate_phone_number(cls, v): + if v is not None: + # Remove spaces and special characters for validation + cleaned = ''.join(filter(str.isdigit, v)) + if len(cleaned) < 10 or len(cleaned) > 15: + raise ValueError('Phone number must be between 10 and 15 digits') + return v + + @validator('date_of_birth') + def validate_date_of_birth(cls, v): + if v is not None: + if v > date.today(): + raise ValueError('Date of birth cannot be in the future') + # Check if age would be reasonable (not more than 120 years old) + age = (date.today() - v).days // 365 + if age > 120: + raise ValueError('Date of birth indicates unrealistic age') + return v + +class GuestUpdateRequest(BaseModel): + """Schema for updating a guest profile""" + first_name: Optional[str] = Field(None, min_length=1, max_length=100, description="Guest's first name") + last_name: Optional[str] = Field(None, max_length=100, description="Guest's last name") + email: Optional[EmailStr] = Field(None, description="Guest's email address") + phone_number: Optional[str] = Field(None, max_length=20, description="Guest's phone number") + gender: Optional[GenderEnum] = Field(None, description="Guest's gender") + date_of_birth: Optional[date] = Field(None, description="Guest's date of birth for age calculation") + relationship: Optional[RelationshipEnum] = Field(None, description="Relationship to the user") + notes: Optional[str] = Field(None, max_length=500, description="Additional notes about the guest") + is_default: Optional[bool] = Field(None, description="Mark as default guest") + + @validator('email', pre=True) + def optional_email_empty_to_none_update(cls, v): + if v is None: + return None + if isinstance(v, str) and v.strip() == '': + return None + return v + + @validator('phone_number', pre=True) + def optional_phone_empty_to_none_update(cls, v): + if v is None: + return None + if isinstance(v, str) and v.strip() == '': + return None + return v + + @validator('date_of_birth', pre=True) + def coerce_date_of_birth_update(cls, v): + if v is None: + return None + if isinstance(v, str): + s = v.strip() + if s == '': + return None + try: + if 'T' in s or 'Z' in s or '+' in s: + return datetime.fromisoformat(s.replace('Z', '+00:00')).date() + return date.fromisoformat(s) + except Exception: + return v + if isinstance(v, datetime): + return v.date() + return v + + @validator('first_name') + def validate_first_name(cls, v): + if v is not None and (not v or not v.strip()): + raise ValueError('First name cannot be empty') + return v.strip() if v else v + + @validator('last_name') + def validate_last_name(cls, v): + if v is not None and v.strip() == '': + return None + return v.strip() if v else v + + @validator('phone_number') + def validate_phone_number(cls, v): + if v is not None: + # Remove spaces and special characters for validation + cleaned = ''.join(filter(str.isdigit, v)) + if len(cleaned) < 10 or len(cleaned) > 15: + raise ValueError('Phone number must be between 10 and 15 digits') + return v + + @validator('date_of_birth') + def validate_date_of_birth(cls, v): + if v is not None: + if v > date.today(): + raise ValueError('Date of birth cannot be in the future') + # Check if age would be reasonable (not more than 120 years old) + age = (date.today() - v).days // 365 + if age > 120: + raise ValueError('Date of birth indicates unrealistic age') + return v + +class GuestResponse(BaseModel): + """Schema for guest profile response""" + guest_id: str = Field(..., description="Unique guest identifier") + customer_id: str = Field(..., description="User ID who created this guest profile") + first_name: str = Field(..., description="Guest's first name") + last_name: Optional[str] = Field(None, description="Guest's last name") + email: Optional[str] = Field(None, description="Guest's email address") + phone_number: Optional[str] = Field(None, description="Guest's phone number") + gender: Optional[str] = Field(None, description="Guest's gender") + date_of_birth: Optional[date] = Field(None, description="Guest's date of birth") + relationship: Optional[str] = Field(None, description="Relationship to the user") + notes: Optional[str] = Field(None, description="Additional notes about the guest") + is_default: bool = Field(..., description="Is default guest") + created_at: datetime = Field(..., description="Guest profile creation timestamp") + updated_at: datetime = Field(..., description="Guest profile last update timestamp") + + @property + def full_name(self) -> str: + """Get the full name of the guest""" + if self.last_name: + return f"{self.first_name} {self.last_name}" + return self.first_name + + @property + def age(self) -> Optional[int]: + """Calculate age from date of birth""" + if self.date_of_birth: + today = datetime.now() + return today.year - self.date_of_birth.year - ( + (today.month, today.day) < (self.date_of_birth.month, self.date_of_birth.day) + ) + return None + + class Config: + from_attributes = True + json_encoders = { + datetime: lambda v: v.isoformat() + } + +class GuestListResponse(BaseModel): + """Schema for list of guests response""" + guests: List[GuestResponse] = Field(..., description="List of user's guests") + total_count: int = Field(..., description="Total number of guests") + + class Config: + from_attributes = True + +class GuestDeleteResponse(BaseModel): + """Schema for guest deletion response""" + message: str = Field(..., description="Deletion confirmation message") + guest_id: str = Field(..., description="ID of the deleted guest") + + class Config: + from_attributes = True + +class SetDefaultGuestRequest(BaseModel): + """Request model for setting default guest""" + guest_id: str = Field(..., description="Guest ID to set as default") \ No newline at end of file diff --git a/app/schemas/pet_schema.py b/app/schemas/pet_schema.py new file mode 100644 index 0000000000000000000000000000000000000000..f2bd7ae96e66a1a4e2e8ed414d1fe6528300d799 --- /dev/null +++ b/app/schemas/pet_schema.py @@ -0,0 +1,114 @@ +from pydantic import BaseModel, Field, validator +from typing import Optional, List +from datetime import datetime,date +from enum import Enum + +class SpeciesEnum(str, Enum): + DOG = "Dog" + CAT = "Cat" + OTHER = "Other" + +class GenderEnum(str, Enum): + MALE = "Male" + FEMALE = "Female" + OTHER = "Other" + +class TemperamentEnum(str, Enum): + CALM = "Calm" + NERVOUS = "Nervous" + AGGRESSIVE = "Aggressive" + SOCIAL = "Social" + +class PetCreateRequest(BaseModel): + """Schema for creating a new pet profile""" + pet_name: str = Field(..., min_length=1, max_length=100, description="Name of the pet") + species: SpeciesEnum = Field(..., description="Species of the pet") + breed: Optional[str] = Field(None, max_length=100, description="Breed of the pet") + date_of_birth: Optional[date] = Field(None, description="Pet's date of birth") + age: Optional[int] = Field(None, ge=0, le=50, description="Pet's age in years") + weight: Optional[float] = Field(None, ge=0, le=200, description="Pet's weight in kg") + gender: Optional[GenderEnum] = Field(None, description="Pet's gender") + temperament: Optional[TemperamentEnum] = Field(None, description="Pet's temperament") + health_notes: Optional[str] = Field(None, max_length=1000, description="Health notes, allergies, medications") + is_vaccinated: bool = Field(False, description="Vaccination status") + pet_photo_url: Optional[str] = Field(None, max_length=500, description="URL to pet's photo") + + @validator('pet_name') + def validate_pet_name(cls, v): + if not v or not v.strip(): + raise ValueError('Pet name cannot be empty') + return v.strip() + + @validator('age', 'date_of_birth') + def validate_age_or_dob(cls, v, values): + # At least one of age or date_of_birth should be provided + if 'age' in values and 'date_of_birth' in values: + if not values.get('age') and not values.get('date_of_birth'): + raise ValueError('Either age or date of birth must be provided') + return v + +class PetUpdateRequest(BaseModel): + """Schema for updating a pet profile""" + pet_name: Optional[str] = Field(None, min_length=1, max_length=100, description="Name of the pet") + species: Optional[SpeciesEnum] = Field(None, description="Species of the pet") + breed: Optional[str] = Field(None, max_length=100, description="Breed of the pet") + date_of_birth: Optional[date] = Field(None, description="Pet's date of birth") + age: Optional[int] = Field(None, ge=0, le=50, description="Pet's age in years") + weight: Optional[float] = Field(None, ge=0, le=200, description="Pet's weight in kg") + gender: Optional[GenderEnum] = Field(None, description="Pet's gender") + temperament: Optional[TemperamentEnum] = Field(None, description="Pet's temperament") + health_notes: Optional[str] = Field(None, max_length=1000, description="Health notes, allergies, medications") + is_vaccinated: Optional[bool] = Field(None, description="Vaccination status") + pet_photo_url: Optional[str] = Field(None, max_length=500, description="URL to pet's photo") + is_default: Optional[bool] = Field(None, description="Mark as default pet") + + @validator('pet_name') + def validate_pet_name(cls, v): + if v is not None and (not v or not v.strip()): + raise ValueError('Pet name cannot be empty') + return v.strip() if v else v + +class PetResponse(BaseModel): + """Schema for pet profile response""" + pet_id: str = Field(..., description="Unique pet identifier") + customer_id: str = Field(..., description="Owner's user ID") + pet_name: str = Field(..., description="Name of the pet") + species: str = Field(..., description="Species of the pet") + breed: Optional[str] = Field(None, description="Breed of the pet") + date_of_birth: Optional[datetime] = Field(None, description="Pet's date of birth") + age: Optional[int] = Field(None, description="Pet's age in years") + weight: Optional[float] = Field(None, description="Pet's weight in kg") + gender: Optional[str] = Field(None, description="Pet's gender") + temperament: Optional[str] = Field(None, description="Pet's temperament") + health_notes: Optional[str] = Field(None, description="Health notes, allergies, medications") + is_vaccinated: bool = Field(..., description="Vaccination status") + pet_photo_url: Optional[str] = Field(None, description="URL to pet's photo") + is_default: bool = Field(..., description="Is default pet") + created_at: datetime = Field(..., description="Pet profile creation timestamp") + updated_at: datetime = Field(..., description="Pet profile last update timestamp") + + class Config: + from_attributes = True + json_encoders = { + datetime: lambda v: v.isoformat() + } + +class PetListResponse(BaseModel): + """Schema for list of pets response""" + pets: List[PetResponse] = Field(..., description="List of user's pets") + total_count: int = Field(..., description="Total number of pets") + + class Config: + from_attributes = True + +class PetDeleteResponse(BaseModel): + """Schema for pet deletion response""" + message: str = Field(..., description="Deletion confirmation message") + pet_id: str = Field(..., description="ID of the deleted pet") + + class Config: + from_attributes = True + +class SetDefaultPetRequest(BaseModel): + """Request model for setting default pet""" + pet_id: str = Field(..., description="Pet ID to set as default") \ No newline at end of file diff --git a/app/schemas/profile_schema.py b/app/schemas/profile_schema.py new file mode 100644 index 0000000000000000000000000000000000000000..269883e14d04c1ea06d450517e02ee505277e22b --- /dev/null +++ b/app/schemas/profile_schema.py @@ -0,0 +1,91 @@ +from pydantic import BaseModel, EmailStr, Field, validator +from datetime import datetime, date +from typing import Optional, Dict, Any +import re + +class ProfileUpdateRequest(BaseModel): + """Request model for updating user profile""" + name: Optional[str] = Field(None, min_length=2, max_length=100, description="User's full name") + email: Optional[EmailStr] = Field(None, description="User's email address") + phone: Optional[str] = Field(None, description="User's phone number") + date_of_birth: Optional[str] = Field(None, description="Date of birth in DD/MM/YYYY format") + profile_picture: Optional[str] = Field(None, description="Profile picture URL") + + @validator('phone') + def validate_phone(cls, v): + if v is not None: + # Remove any non-digit characters + phone_digits = re.sub(r'\D', '', v) + if len(phone_digits) != 10: + raise ValueError('Phone number must be exactly 10 digits') + return phone_digits + return v + + @validator('date_of_birth') + def validate_date_of_birth(cls, v): + if v is not None: + try: + # Parse DD/MM/YYYY format + day, month, year = map(int, v.split('/')) + birth_date = date(year, month, day) + + # Check if date is not in the future + if birth_date > date.today(): + raise ValueError('Date of birth cannot be in the future') + + # Check if age is reasonable (not more than 120 years) + age = (date.today() - birth_date).days // 365 + if age > 120: + raise ValueError('Invalid date of birth') + + return v + except ValueError as e: + if "Invalid date of birth" in str(e) or "Date of birth cannot be in the future" in str(e): + raise e + raise ValueError('Date of birth must be in DD/MM/YYYY format') + return v + +class ProfileResponse(BaseModel): + """Response model for user profile""" + customer_id: str = Field(..., description="Unique user identifier") + name: str = Field(..., description="User's full name") + email: Optional[str] = Field(None, description="User's email address") + phone: Optional[str] = Field(None, description="User's phone number") + date_of_birth: Optional[str] = Field(None, description="Date of birth in DD/MM/YYYY format") + profile_picture: Optional[str] = Field(None, description="Profile picture URL") + auth_method: str = Field(..., description="Authentication method used") + created_at: datetime = Field(..., description="Account creation timestamp") + updated_at: Optional[datetime] = Field(None, description="Last profile update timestamp") + + class Config: + from_attributes = True + +class PersonalDetailsResponse(BaseModel): + """Response model for personal details section""" + first_name: str = Field(..., description="User's first name") + last_name: str = Field(..., description="User's last name") + email: str = Field(..., description="User's email address") + phone: str = Field(..., description="User's phone number") + date_of_birth: Optional[str] = Field(None, description="Date of birth in DD/MM/YYYY format") + + class Config: + from_attributes = True + +class ProfileOperationResponse(BaseModel): + """Response model for profile operations""" + success: bool = Field(..., description="Operation success status") + message: str = Field(..., description="Response message") + profile: Optional[ProfileResponse] = Field(None, description="Updated profile data") + +class WalletDisplayResponse(BaseModel): + """Response model for wallet display in profile""" + balance: float = Field(..., description="Current wallet balance") + formatted_balance: str = Field(..., description="Formatted balance with currency symbol") + currency: str = Field(default="INR", description="Currency code") + +class ProfileDashboardResponse(BaseModel): + """Complete response model for profile dashboard""" + personal_details: PersonalDetailsResponse = Field(..., description="Personal details") + wallet: WalletDisplayResponse = Field(..., description="Wallet information") + address_count: int = Field(..., description="Number of saved addresses") + has_default_address: bool = Field(..., description="Whether user has a default address set") \ No newline at end of file diff --git a/app/schemas/review_schema.py b/app/schemas/review_schema.py new file mode 100644 index 0000000000000000000000000000000000000000..aa3c549b225c6195abd0765f5994d313c0119761 --- /dev/null +++ b/app/schemas/review_schema.py @@ -0,0 +1,24 @@ +from pydantic import BaseModel, Field +from typing import Optional, List +from datetime import datetime + +class ReviewCreateRequest(BaseModel): + """Request schema for creating a review""" + merchant_id: str = Field(..., description="Unique ID of the merchant") + location_id: str = Field(..., description="ID of the merchant location") + user_name: str = Field(..., description="Name of the user submitting the review") + rating: float = Field(..., description="Rating given by the user") + review_text: str = Field(..., description="review text provided by the user") + verified_purchase: bool = Field(..., description="Indicates if the review is from a verified purchase") + + +class ReviewResponse(BaseModel): + """Response schema for creating a review""" + merchant_id: str = Field(..., description="Unique ID of the merchant") + location_id: str = Field(..., description="ID of the merchant location") + user_name: str = Field(..., description="Name of the user submitting the review") + rating: float = Field(..., description="Rating given by the user") + review_text: str = Field(..., description="review text provided by the user") + review_date: datetime = Field(..., description="Date and time when the review was created") + verified_purchase: bool = Field(..., description="Indicates if the review is from a verified purchase") + diff --git a/app/schemas/user_schema.py b/app/schemas/user_schema.py new file mode 100644 index 0000000000000000000000000000000000000000..cbefe08918681cbc3408117f93488ba7817783e2 --- /dev/null +++ b/app/schemas/user_schema.py @@ -0,0 +1,198 @@ +from pydantic import BaseModel, EmailStr, validator +from typing import Optional, Literal, List, Dict, Any +from datetime import datetime +import re + +# Used for OTP-based or OAuth-based user registration +class UserRegisterRequest(BaseModel): + name: str + email: EmailStr # Mandatory for all registration modes + phone: str # Mandatory for all registration modes (always used as OTP identifier) + otp: Optional[str] = None # Required for OTP mode + dob: Optional[str] = None # ISO format date string + oauth_token: Optional[str] = None # Required for OAuth mode + provider: Optional[Literal["google", "apple", "facebook"]] = None # Required for OAuth mode + mode: Literal["otp", "oauth"] + remember_me: Optional[bool] = False + device_info: Optional[str] = None + + @validator('phone') + def validate_phone(cls, v): + if v is not None: + # Remove all non-digit characters except + + cleaned = re.sub(r'[^\d+]', '', v) + # Check if it's a valid phone number (8-15 digits, optionally starting with +) + if not re.match(r'^\+?[1-9]\d{7,14}$', cleaned): + raise ValueError('Invalid phone number format') + return v + + @validator('mode') + def validate_mode_dependent_fields(cls, v, values): + if v == "otp": + if not values.get('otp'): + raise ValueError('OTP is required for OTP registration mode') + elif v == "oauth": + if not values.get('oauth_token'): + raise ValueError('OAuth token is required for OAuth registration mode') + if not values.get('provider'): + raise ValueError('Provider is required for OAuth registration mode') + return v + +# Used in login form (optional display name prefilled from local storage) +class UserLoginRequest(BaseModel): + name: Optional[str] = None + login_input: str # Changed from email to support both email and phone + + @validator('login_input') + def validate_login_input(cls, v): + # Check if it's either email or phone + email_pattern = r"[^@]+@[^@]+\.[^@]+" + phone_cleaned = re.sub(r'[^\d+]', '', v) + phone_pattern = r'^\+?[1-9]\d{7,14}$' + + if not (re.match(email_pattern, v) or re.match(phone_pattern, phone_cleaned)): + raise ValueError('Login input must be a valid email address or phone number') + return v + +# OTP request via email or phone - updated to ensure only one is provided +class OTPRequest(BaseModel): + email: Optional[EmailStr] = None + phone: Optional[str] = None + + @validator('phone') + def validate_phone(cls, v): + if v is not None: + # Remove all non-digit characters except + + cleaned = re.sub(r'[^\d+]', '', v) + # Check if it's a valid phone number (8-15 digits, optionally starting with +) + if not re.match(r'^\+?[1-9]\d{7,14}$', cleaned): + raise ValueError('Invalid phone number format') + return v + + @validator('phone') + def validate_only_one_identifier(cls, v, values): + if v is not None and values.get('email') is not None: + raise ValueError('Provide either email or phone, not both') + if v is None and values.get('email') is None: + raise ValueError('Either email or phone must be provided') + return v + +# Generic OTP request using single login input +class OTPRequestWithLogin(BaseModel): + login_input: str # email or phone + + @validator('login_input') + def validate_login_input(cls, v): + # Check if it's either email or phone + email_pattern = r"[^@]+@[^@]+\.[^@]+" + phone_cleaned = re.sub(r'[^\d+]', '', v) + phone_pattern = r'^\+?[1-9]\d{7,14}$' + + if not (re.match(email_pattern, v) or re.match(phone_pattern, phone_cleaned)): + raise ValueError('Login input must be a valid email address or phone number') + return v + +# OTP verification input +class OTPVerifyRequest(BaseModel): + login_input: str + otp: str + remember_me: Optional[bool] = False + device_info: Optional[str] = None + + @validator('login_input') + def validate_login_input(cls, v): + # Check if it's either email or phone + email_pattern = r"[^@]+@[^@]+\.[^@]+" + phone_cleaned = re.sub(r'[^\d+]', '', v) + phone_pattern = r'^\+?[1-9]\d{7,14}$' + + if not (re.match(email_pattern, v) or re.match(phone_pattern, phone_cleaned)): + raise ValueError('Login input must be a valid email address or phone number') + return v + +# OTP send response with user existence flag +class OTPSendResponse(BaseModel): + message: str + temp_token: str + user_exists: bool = False + +# OAuth login using Google/Apple +class OAuthLoginRequest(BaseModel): + provider: Literal["google", "apple", "facebook"] + token: str + remember_me: Optional[bool] = False + device_info: Optional[str] = None + +# JWT Token response format with enhanced security info +class TokenResponse(BaseModel): + access_token: str + token_type: str = "bearer" + expires_in: Optional[int] = None + refresh_token: Optional[str] = None + customer_id: Optional[str] = None + name: Optional[str] = None + email: Optional[str] = None + profile_picture: Optional[str] = None + auth_method: Optional[str] = None # "otp" or "oauth" + provider: Optional[str] = None # For OAuth logins + user_exists: Optional[bool] = None # Indicates if user already exists for OAuth + security_info: Optional[Dict[str, Any]] = None + +# Enhanced user profile response with social accounts +class UserProfileResponse(BaseModel): + customer_id: str + name: str + email: Optional[EmailStr] = None + phone: Optional[str] = None + profile_picture: Optional[str] = None + auth_method: str + created_at: datetime + social_accounts: Optional[List[Dict[str, Any]]] = None + security_info: Optional[Dict[str, Any]] = None + +# Social account information +class SocialAccountInfo(BaseModel): + provider: str + email: Optional[str] = None + name: Optional[str] = None + linked_at: datetime + last_login: Optional[datetime] = None + +# Social account summary response +class SocialAccountSummary(BaseModel): + linked_accounts: List[SocialAccountInfo] + total_accounts: int + profile_picture: Optional[str] = None + +# Account linking request +class LinkSocialAccountRequest(BaseModel): + provider: Literal["google", "apple", "facebook"] + token: str + +# Account unlinking request +class UnlinkSocialAccountRequest(BaseModel): + provider: Literal["google", "apple", "facebook"] + +# Login history entry +class LoginHistoryEntry(BaseModel): + timestamp: datetime + method: str # "otp" or "oauth" + provider: Optional[str] = None + ip_address: Optional[str] = None + success: bool + device_info: Optional[str] = None + +# Login history response +class LoginHistoryResponse(BaseModel): + entries: List[LoginHistoryEntry] + total_entries: int + page: int + per_page: int + +# Security settings response +class SecuritySettingsResponse(BaseModel): + two_factor_enabled: bool = False + linked_social_accounts: int + last_password_change: Optional[datetime] = None + recent_login_attempts: int + account_locked: bool = False \ No newline at end of file diff --git a/app/schemas/wallet_schema.py b/app/schemas/wallet_schema.py new file mode 100644 index 0000000000000000000000000000000000000000..35b89ddb99dd76ac227e67df5abddd65c56e8282 --- /dev/null +++ b/app/schemas/wallet_schema.py @@ -0,0 +1,71 @@ +from pydantic import BaseModel, Field +from datetime import datetime +from typing import List, Optional, Literal +from decimal import Decimal + +class WalletBalanceResponse(BaseModel): + """Response model for wallet balance""" + balance: float = Field(..., description="Current wallet balance") + currency: str = Field(default="INR", description="Currency code") + formatted_balance: str = Field(..., description="Formatted balance with currency symbol") + +class TransactionEntry(BaseModel): + """Model for individual transaction entry""" + transaction_id: str = Field(..., description="Unique transaction ID") + amount: float = Field(..., description="Transaction amount") + transaction_type: Literal["credit", "debit", "refund", "cashback", "payment", "withdrawal"] = Field(..., description="Type of transaction") + description: str = Field(..., description="Transaction description") + reference_id: Optional[str] = Field(None, description="Reference ID for the transaction") + balance_before: float = Field(..., description="Balance before transaction") + balance_after: float = Field(..., description="Balance after transaction") + timestamp: datetime = Field(..., description="Transaction timestamp") + status: Literal["completed", "pending", "failed"] = Field(default="completed", description="Transaction status") + +class TransactionHistoryResponse(BaseModel): + """Response model for transaction history""" + transactions: List[TransactionEntry] = Field(..., description="List of transactions") + total_count: int = Field(..., description="Total number of transactions") + page: int = Field(..., description="Current page number") + per_page: int = Field(..., description="Number of items per page") + total_pages: int = Field(..., description="Total number of pages") + +class WalletSummaryResponse(BaseModel): + """Response model for wallet summary""" + balance: float = Field(..., description="Current wallet balance") + formatted_balance: str = Field(..., description="Formatted balance with currency symbol") + recent_transactions: List[TransactionEntry] = Field(..., description="Recent transactions") + +class AddMoneyRequest(BaseModel): + """Request model for adding money to wallet""" + amount: float = Field(..., gt=0, description="Amount to add (must be positive)") + payment_method: Literal["card", "upi", "netbanking"] = Field(..., description="Payment method") + reference_id: Optional[str] = Field(None, description="Payment reference ID") + description: Optional[str] = Field("Wallet top-up", description="Transaction description") + +class WithdrawMoneyRequest(BaseModel): + """Request model for withdrawing money from wallet""" + amount: float = Field(..., gt=0, description="Amount to withdraw (must be positive)") + bank_account_id: str = Field(..., description="Bank account ID for withdrawal") + description: Optional[str] = Field("Wallet withdrawal", description="Transaction description") + +class TransactionRequest(BaseModel): + """Generic transaction request model""" + amount: float = Field(..., gt=0, description="Transaction amount") + transaction_type: Literal["credit", "debit", "refund", "cashback", "payment"] = Field(..., description="Transaction type") + description: str = Field(..., description="Transaction description") + reference_id: Optional[str] = Field(None, description="Reference ID") + category: Optional[str] = Field(None, description="Transaction category") + +class TransactionResponse(BaseModel): + """Response model for transaction operations""" + success: bool = Field(..., description="Transaction success status") + message: str = Field(..., description="Response message") + transaction: Optional[TransactionEntry] = Field(None, description="Transaction details") + new_balance: Optional[float] = Field(None, description="New wallet balance after transaction") + +class WalletTransactionResponse(BaseModel): + """Response model for wallet transaction operations""" + success: bool = Field(..., description="Transaction success status") + message: str = Field(..., description="Response message") + transaction_id: Optional[str] = Field(None, description="Transaction ID if successful") + new_balance: Optional[float] = Field(None, description="New wallet balance after transaction") \ No newline at end of file diff --git a/app/services/__init__.py b/app/services/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/app/services/account_service.py b/app/services/account_service.py new file mode 100644 index 0000000000000000000000000000000000000000..d86e51df63a52ab7c44ddbde9ad5666c4881c411 --- /dev/null +++ b/app/services/account_service.py @@ -0,0 +1,396 @@ +from datetime import datetime, timedelta +from typing import List, Dict, Any, Optional +import logging +from bson import ObjectId + +from app.models.social_account_model import SocialAccountModel +from app.models.user_model import BookMyServiceUserModel +from app.schemas.user_schema import ( + SocialAccountSummary, SocialAccountInfo, LoginHistoryResponse, + LoginHistoryEntry, SecuritySettingsResponse +) +from app.utils.social_utils import verify_google_token, verify_apple_token, verify_facebook_token +from app.core.nosql_client import db + +# Configure logging +logger = logging.getLogger(__name__) + +class AccountService: + """Service for managing user accounts, social accounts, and security settings""" + + def __init__(self): + self.security_collection = db.get_collection("security_logs") + self.device_collection = db.get_collection("device_tracking") + self.session_collection = db.get_collection("user_sessions") + + async def get_social_account_summary(self, customer_id: str) -> SocialAccountSummary: + """Get summary of all linked social accounts for a user""" + try: + social_accounts = await SocialAccountModel.find_by_customer_id(customer_id) + + linked_accounts = [] + profile_picture = None + + for account in social_accounts: + account_info = SocialAccountInfo( + provider=account["provider"], + email=account.get("email"), + name=account.get("name"), + linked_at=account["created_at"], + last_login=account.get("last_login") + ) + linked_accounts.append(account_info) + + # Use the first available profile picture + if not profile_picture and account.get("profile_picture"): + profile_picture = account["profile_picture"] + + return SocialAccountSummary( + linked_accounts=linked_accounts, + total_accounts=len(linked_accounts), + profile_picture=profile_picture + ) + + except Exception as e: + logger.error(f"Error getting social account summary for user {customer_id}: {str(e)}") + raise + + async def link_social_account(self, customer_id: str, provider: str, token: str, client_ip: str) -> Dict[str, Any]: + """Link a new social account to an existing user""" + try: + # Verify the token and get user info + user_info = await self._verify_social_token(provider, token) + + # Check if this social account is already linked to another user + existing_account = await SocialAccountModel.find_by_provider_id( + provider, user_info["id"] + ) + + if existing_account and existing_account["customer_id"] != customer_id: + raise ValueError(f"This {provider} account is already linked to another user") + + # Check if user already has this provider linked + user_provider_account = await SocialAccountModel.find_by_user_and_provider( + customer_id, provider + ) + + if user_provider_account: + # Update existing account + await SocialAccountModel.update_social_account( + customer_id, provider, user_info, client_ip + ) + action = "updated" + else: + # Create new social account link + await SocialAccountModel.create_social_account( + customer_id, provider, user_info, client_ip + ) + action = "linked" + + # Log the action + await self._log_account_action( + customer_id, f"social_account_{action}", + {"provider": provider, "client_ip": client_ip} + ) + + return {"action": action, "provider": provider, "user_info": user_info} + + except Exception as e: + logger.error(f"Error linking social account for user {customer_id}: {str(e)}") + raise + + async def unlink_social_account(self, customer_id: str, provider: str) -> Dict[str, Any]: + """Unlink a social account from a user""" + try: + # Check if account exists + account = await SocialAccountModel.find_by_user_and_provider(customer_id, provider) + if not account: + raise ValueError(f"No {provider} account found for this user") + + # Check if this is the only authentication method + user = await BookMyServiceUserModel.find_by_id(customer_id) + if not user: + raise ValueError("User not found") + + # Count total social accounts + social_accounts = await SocialAccountModel.find_by_customer_id(customer_id) + + # If user has no phone/email and this is their only social account, prevent unlinking + if (len(social_accounts) == 1 and + not user.get("phone") and not user.get("email")): + raise ValueError("Cannot unlink the only authentication method") + + # Unlink the account + result = await SocialAccountModel.unlink_social_account(customer_id, provider) + + # Log the action + await self._log_account_action( + customer_id, "social_account_unlinked", + {"provider": provider} + ) + + return {"action": "unlinked", "provider": provider, "result": result} + + except Exception as e: + logger.error(f"Error unlinking social account for user {customer_id}: {str(e)}") + raise + + async def get_login_history(self, customer_id: str, page: int = 1, + per_page: int = 10, days: int = 30) -> LoginHistoryResponse: + """Get login history for a user""" + try: + # Calculate date range + end_date = datetime.utcnow() + start_date = end_date - timedelta(days=days) + + # Query security logs for login events + skip = (page - 1) * per_page + + pipeline = [ + { + "$match": { + "customer_id": customer_id, + "timestamp": {"$gte": start_date, "$lte": end_date}, + "$or": [ + {"path": {"$regex": "/login"}}, + {"path": {"$regex": "/oauth"}}, + {"path": {"$regex": "/otp"}} + ] + } + }, + {"$sort": {"timestamp": -1}}, + {"$skip": skip}, + {"$limit": per_page} + ] + + cursor = self.security_collection.aggregate(pipeline) + logs = await cursor.to_list(length=per_page) + + # Count total entries + total_count = await self.security_collection.count_documents({ + "customer_id": customer_id, + "timestamp": {"$gte": start_date, "$lte": end_date}, + "$or": [ + {"path": {"$regex": "/login"}}, + {"path": {"$regex": "/oauth"}}, + {"path": {"$regex": "/otp"}} + ] + }) + + # Convert to response format + entries = [] + for log in logs: + method = "oauth" if "oauth" in log["path"] else "otp" + provider = None + + # Extract provider from query params if available + if method == "oauth" and log.get("query_params"): + provider = log["query_params"].get("provider") + + entry = LoginHistoryEntry( + timestamp=log["timestamp"], + method=method, + provider=provider, + ip_address=log.get("client_ip"), + success=log["status_code"] < 400, + device_info=log.get("device_info", {}).get("user_agent") + ) + entries.append(entry) + + return LoginHistoryResponse( + entries=entries, + total_entries=total_count, + page=page, + per_page=per_page + ) + + except Exception as e: + logger.error(f"Error getting login history for user {customer_id}: {str(e)}") + raise + + async def get_security_settings(self, customer_id: str) -> SecuritySettingsResponse: + """Get security settings and status for a user""" + try: + # Get user info + user = await BookMyServiceUserModel.find_by_id(customer_id) + if not user: + raise ValueError("User not found") + + # Count linked social accounts + social_accounts = await SocialAccountModel.find_by_customer_id(customer_id) + linked_accounts_count = len(social_accounts) + + # Get recent login attempts (last 24 hours) + yesterday = datetime.utcnow() - timedelta(days=1) + recent_attempts = await self.security_collection.count_documents({ + "customer_id": customer_id, + "timestamp": {"$gte": yesterday}, + "$or": [ + {"path": {"$regex": "/login"}}, + {"path": {"$regex": "/oauth"}}, + {"path": {"$regex": "/otp"}} + ] + }) + + # Check if account is locked (this would be implemented based on your locking logic) + account_locked = False # Implement based on your account locking mechanism + + return SecuritySettingsResponse( + two_factor_enabled=False, # Implement 2FA if needed + linked_social_accounts=linked_accounts_count, + last_password_change=None, # Implement if you have password functionality + recent_login_attempts=recent_attempts, + account_locked=account_locked + ) + + except Exception as e: + logger.error(f"Error getting security settings for user {customer_id}: {str(e)}") + raise + + async def merge_social_accounts(self, primary_customer_id: str, secondary_customer_id: str, + client_ip: str) -> Dict[str, Any]: + """Merge social accounts from secondary user to primary user""" + try: + # Get social accounts from secondary user + secondary_accounts = await SocialAccountModel.find_by_customer_id(secondary_customer_id) + + merged_count = 0 + for account in secondary_accounts: + # Check if primary user already has this provider + existing = await SocialAccountModel.find_by_user_and_provider( + primary_customer_id, account["provider"] + ) + + if not existing: + # Transfer the account to primary user + await SocialAccountModel.update_customer_id( + account["_id"], primary_customer_id + ) + merged_count += 1 + + # Log the merge action + await self._log_account_action( + primary_customer_id, "accounts_merged", + { + "secondary_customer_id": secondary_customer_id, + "merged_accounts": merged_count, + "client_ip": client_ip + } + ) + + return { + "merged_accounts": merged_count, + "primary_customer_id": primary_customer_id, + "secondary_customer_id": secondary_customer_id + } + + except Exception as e: + logger.error(f"Error merging accounts {secondary_customer_id} -> {primary_customer_id}: {str(e)}") + raise + + async def revoke_all_sessions(self, customer_id: str, client_ip: str) -> Dict[str, Any]: + """Revoke all active sessions for a user""" + try: + # In a real implementation, you'd have a sessions collection + # For now, we'll just log the action + await self._log_account_action( + customer_id, "all_sessions_revoked", + {"client_ip": client_ip} + ) + + # Here you would typically: + # 1. Delete all session tokens from database + # 2. Add tokens to a blacklist + # 3. Force re-authentication on next request + + return {"action": "revoked", "customer_id": customer_id} + + except Exception as e: + logger.error(f"Error revoking sessions for user {customer_id}: {str(e)}") + raise + + async def get_trusted_devices(self, customer_id: str) -> List[Dict[str, Any]]: + """Get list of trusted devices for a user""" + try: + cursor = self.device_collection.find({ + "customer_id": customer_id, + "is_trusted": True + }).sort("last_seen", -1) + + devices = await cursor.to_list(length=None) + + # Format device information + trusted_devices = [] + for device in devices: + device_info = { + "device_id": str(device["_id"]), + "device_fingerprint": device["device_fingerprint"], + "platform": device.get("device_info", {}).get("platform", "Unknown"), + "browser": device.get("device_info", {}).get("browser", "Unknown"), + "first_seen": device["first_seen"], + "last_seen": device["last_seen"], + "access_count": device.get("access_count", 0) + } + trusted_devices.append(device_info) + + return trusted_devices + + except Exception as e: + logger.error(f"Error getting trusted devices for user {customer_id}: {str(e)}") + raise + + async def remove_trusted_device(self, customer_id: str, device_id: str) -> Dict[str, Any]: + """Remove a trusted device""" + try: + result = await self.device_collection.update_one( + { + "_id": ObjectId(device_id), + "customer_id": customer_id + }, + {"$set": {"is_trusted": False}} + ) + + if result.matched_count == 0: + raise ValueError("Device not found or not owned by user") + + await self._log_account_action( + customer_id, "trusted_device_removed", + {"device_id": device_id} + ) + + return {"action": "removed", "device_id": device_id} + + except Exception as e: + logger.error(f"Error removing trusted device for user {customer_id}: {str(e)}") + raise + + async def _verify_social_token(self, provider: str, token: str) -> Dict[str, Any]: + """Verify social media token and return user info""" + try: + if provider == "google": + return await verify_google_token(token) + elif provider == "apple": + return await verify_apple_token(token) + elif provider == "facebook": + return await verify_facebook_token(token) + else: + raise ValueError(f"Unsupported provider: {provider}") + except Exception as e: + logger.error(f"Token verification failed for {provider}: {str(e)}") + raise ValueError(f"Invalid {provider} token") + + async def _log_account_action(self, customer_id: str, action: str, details: Dict[str, Any]): + """Log account-related actions for audit purposes""" + try: + log_entry = { + "timestamp": datetime.utcnow(), + "customer_id": customer_id, + "action": action, + "details": details, + "type": "account_management" + } + + await self.security_collection.insert_one(log_entry) + + except Exception as e: + logger.error(f"Failed to log account action: {str(e)}") \ No newline at end of file diff --git a/app/services/favorite_service.py b/app/services/favorite_service.py new file mode 100644 index 0000000000000000000000000000000000000000..0e7a539dbce1a9e1689a201002eab16169b68944 --- /dev/null +++ b/app/services/favorite_service.py @@ -0,0 +1,158 @@ +from fastapi import HTTPException, Depends +from typing import Optional +from app.models.favorite_model import BookMyServiceFavoriteModel +from app.schemas.favorite_schema import ( + FavoriteCreateRequest, + FavoriteUpdateRequest, + FavoriteResponse, + FavoritesListResponse, + FavoriteStatusResponse, + FavoriteSuccessResponse, + FavoriteDataResponse +) +from app.services.user_service import UserService +import logging + +logger = logging.getLogger("favorite_service") + +class FavoriteService: + + @staticmethod + async def add_favorite( + customer_id: str, + favorite_data: FavoriteCreateRequest + ) -> FavoriteSuccessResponse: + """Add a merchant to user's favorites""" + try: + # Create favorite in database + favorite_merchant_id= await BookMyServiceFavoriteModel.create_favorite( + customer_id, + favorite_data.dict() + ) + + return FavoriteSuccessResponse( + success=True, + message="Merchant added to favorites successfully", + merchant_id=favorite_merchant_id + ) + + except HTTPException as e: + raise e + except Exception as e: + logger.error(f"Error adding favorite: {str(e)}", exc_info=True) + raise HTTPException(status_code=500, detail="Failed to add favorite") + + @staticmethod + async def remove_favorite(customer_id: str, merchant_id: str) -> FavoriteSuccessResponse: + """Remove a merchant from user's favorites""" + try: + await BookMyServiceFavoriteModel.delete_favorite(customer_id, merchant_id) + + return FavoriteSuccessResponse( + success=True, + message="Merchant removed from favorites successfully", + merchant_id=None + ) + + except HTTPException as e: + raise e + except Exception as e: + logger.error(f"Error removing favorite: {str(e)}", exc_info=True) + raise HTTPException(status_code=500, detail="Failed to remove favorite") + + @staticmethod + async def get_favorites( + customer_id: str, + limit: int = 50 + ) -> FavoritesListResponse: + """Get user's favorite merchants""" + try: + result = await BookMyServiceFavoriteModel.get_favorites(customer_id,limit ) + + # Convert MongoDB documents to response format + favorites = [] + for fav in result["favorites"]: + favorites.append(FavoriteResponse( + merchant_id=fav["merchant_id"], + merchant_category=fav["merchant_category"], + merchant_name=fav["merchant_name"], + source=fav["source"], + added_at=fav["added_at"], + notes=fav.get("notes") + )) + + return FavoritesListResponse( + favorites=favorites, + total_count=result["total_count"], + limit=result["limit"] + ) + + except Exception as e: + logger.error(f"Error getting favorites: {str(e)}", exc_info=True) + raise HTTPException(status_code=500, detail="Failed to get favorites") + + @staticmethod + async def check_favorite_status(customer_id: str, merchant_id: str) -> FavoriteStatusResponse: + """Check if a merchant is in user's favorites""" + try: + favorite_data = await BookMyServiceFavoriteModel.get_favorite(customer_id, merchant_id) + + return FavoriteStatusResponse( + is_favorite=bool(favorite_data), + merchant_id=merchant_id + ) + + except Exception as e: + logger.error(f"Error checking favorite status: {str(e)}", exc_info=True) + raise HTTPException(status_code=500, detail="Failed to check favorite status") + + @staticmethod + async def update_favorite_notes( + customer_id: str, + merchant_id: str, + notes_data: FavoriteUpdateRequest + ) -> FavoriteSuccessResponse: + """Update notes for a favorite merchant""" + try: + await BookMyServiceFavoriteModel.update_favorite_notes( + customer_id=customer_id, + merchant_id=merchant_id, + notes=notes_data.notes + ) + + return FavoriteSuccessResponse( + success=True, + message="Favorite notes updated successfully", + merchant_id=merchant_id + ) + + except HTTPException as e: + raise e + except Exception as e: + logger.error(f"Error updating favorite notes: {str(e)}", exc_info=True) + raise HTTPException(status_code=500, detail="Failed to update favorite notes") + + @staticmethod + async def get_favorite_details(customer_id: str, merchant_id: str)->FavoriteDataResponse: + """Get detailed information about a specific favorite""" + try: + favorite = await BookMyServiceFavoriteModel.get_favorite(customer_id, merchant_id) + + if favorite: + return FavoriteDataResponse( + success=True, + message="Favorite found successfully", + favorite_data=favorite + ) + else: + return FavoriteDataResponse( + success=False, + message="Favorite not found", + favorite_data=None + ) + + except HTTPException as e: + raise e + except Exception as e: + logger.error(f"Error getting favorite details: {str(e)}", exc_info=True) + raise HTTPException(status_code=500, detail="Failed to get favorite details") \ No newline at end of file diff --git a/app/services/otp_service.py b/app/services/otp_service.py new file mode 100644 index 0000000000000000000000000000000000000000..b0fe1e124d307fede897e7a5ff5cee78e6064489 --- /dev/null +++ b/app/services/otp_service.py @@ -0,0 +1,36 @@ +from fastapi import HTTPException +from app.core.cache_client import get_redis +from app.utils.sms_utils import send_sms_otp +from app.utils.email_utils import send_email_otp +from app.utils.common_utils import is_email + +class BookMyServiceOTPModel: + OTP_TTL = 300 # 5 minutes + RATE_LIMIT_MAX = 3 + RATE_LIMIT_WINDOW = 600 # 10 minutes + + @staticmethod + async def store_otp(identifier: str, phone: str, otp: str, ttl: int = OTP_TTL): + redis = await get_redis() + + rate_key = f"otp_rate_limit:{identifier}" + attempts = await redis.incr(rate_key) + if attempts == 1: + await redis.expire(rate_key, BookMyServiceOTPModel.RATE_LIMIT_WINDOW) + elif attempts > BookMyServiceOTPModel.RATE_LIMIT_MAX: + raise HTTPException(status_code=429, detail="Too many OTP requests. Try again later.") + + await redis.setex(f"bms_otp:{identifier}", ttl, otp) + + + + + @staticmethod + async def verify_otp(identifier: str, otp: str): + redis = await get_redis() + key = f"bms_otp:{identifier}" + stored = await redis.get(key) + if stored and stored == otp: + await redis.delete(key) + return True + return False \ No newline at end of file diff --git a/app/services/profile_service.py b/app/services/profile_service.py new file mode 100644 index 0000000000000000000000000000000000000000..0e05eecd10990ab15fd629e2ea86baf0e576264a --- /dev/null +++ b/app/services/profile_service.py @@ -0,0 +1,72 @@ +""" +Profile service for customer profile operations. +""" + +import logging +from typing import Optional, Dict, Any +from bson import ObjectId +from app.core.nosql_client import db +from fastapi import HTTPException, status + +logger = logging.getLogger(__name__) + +class ProfileService: + """Service class for profile-related operations.""" + + @staticmethod + async def get_customer_profile(customer_id: str) -> Dict[str, Any]: + """ + Fetch customer profile from the customers collection. + + Args: + customer_id (str): The user ID from JWT token + + Returns: + Dict[str, Any]: Customer profile data + + Raises: + HTTPException: If customer not found or database error + """ + try: + # Convert string ID to ObjectId if needed + if ObjectId.is_valid(customer_id): + query = {"_id": ObjectId(customer_id)} + else: + # If not a valid ObjectId, try searching by other fields + query = {"$or": [ + {"customer_id": customer_id}, + {"email": customer_id}, + {"mobile": customer_id} + ]} + + logger.info(f"Fetching profile for user: {customer_id}") + + # Fetch customer from customers collection + customer = await db.customers.find_one(query) + + if not customer: + logger.warning(f"Customer not found for customer_id: {customer_id}") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Customer profile not found" + ) + + # Convert ObjectId to string for JSON serialization + if "_id" in customer: + customer["_id"] = str(customer["_id"]) + + logger.info(f"Successfully fetched profile for user: {customer_id}") + return customer + + except HTTPException: + # Re-raise HTTP exceptions + raise + except Exception as e: + logger.error(f"Error fetching customer profile for user {customer_id}: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Internal server error while fetching profile" + ) + +# Create service instance +profile_service = ProfileService() \ No newline at end of file diff --git a/app/services/user_service.py b/app/services/user_service.py new file mode 100644 index 0000000000000000000000000000000000000000..23cc08e6d10f6bcd7204edf62f3900f0e5d82e9c --- /dev/null +++ b/app/services/user_service.py @@ -0,0 +1,350 @@ +import random +import uuid +from jose import jwt +from datetime import datetime, timedelta +from fastapi import HTTPException +from app.models.user_model import BookMyServiceUserModel +from app.models.otp_model import BookMyServiceOTPModel +from app.models.social_account_model import SocialAccountModel +from app.models.refresh_token_model import RefreshTokenModel +from app.core.config import settings +from app.utils.common_utils import is_email, validate_identifier +from app.utils.jwt import create_refresh_token +from app.schemas.user_schema import UserRegisterRequest +import logging + +logger = logging.getLogger("user_service") + + + + +class UserService: + @staticmethod + async def send_otp(identifier: str, phone: str = None, client_ip: str = None): + logger.info(f"UserService.send_otp called - identifier: {identifier}, phone: {phone}, ip: {client_ip}") + + try: + # Validate identifier format + identifier_type = validate_identifier(identifier) + logger.info(f"Identifier type: {identifier_type}") + + # Enhanced rate limiting by IP and identifier + if client_ip: + ip_rate_key = f"otp_ip_rate:{client_ip}" + ip_attempts = await BookMyServiceOTPModel.get_rate_limit_count(ip_rate_key) + if ip_attempts >= 10: # Max 10 OTPs per IP per hour + logger.warning(f"IP rate limit exceeded for {client_ip}") + raise HTTPException(status_code=429, detail="Too many OTP requests from this IP") + + # For phone identifiers, use the identifier itself as phone + # For email identifiers, use the provided phone parameter + if identifier_type == "phone": + phone_number = identifier + elif identifier_type == "email" and phone: + phone_number = phone + else: + # If email identifier but no phone provided, we'll send OTP via email + phone_number = None + + # Generate OTP - hardcoded for testing purposes + otp = '777777' + logger.info(f"Generated hardcoded OTP for identifier: {identifier}") + + await BookMyServiceOTPModel.store_otp(identifier, phone_number, otp) + + # Track IP-based rate limiting + if client_ip: + await BookMyServiceOTPModel.increment_rate_limit(ip_rate_key, 3600) # 1 hour window + + logger.info(f"OTP stored successfully for identifier: {identifier}") + logger.info(f"OTP sent to {identifier}") + + except ValueError as ve: + logger.error(f"Validation error for identifier {identifier}: {str(ve)}") + raise HTTPException(status_code=400, detail=str(ve)) + except Exception as e: + logger.error(f"Error in send_otp for identifier {identifier}: {str(e)}", exc_info=True) + raise HTTPException(status_code=500, detail="Failed to send OTP") + + @staticmethod + async def otp_login_handler( + identifier: str, + otp: str, + client_ip: str = None, + remember_me: bool = False, + device_info: str = None + ): + logger.info(f"UserService.otp_login_handler called - identifier: {identifier}, otp: {otp}, ip: {client_ip}, remember_me: {remember_me}") + + try: + # Validate identifier format + identifier_type = validate_identifier(identifier) + logger.info(f"Identifier type: {identifier_type}") + + # Check if account is locked + if await BookMyServiceOTPModel.is_account_locked(identifier): + logger.warning(f"Account locked for identifier: {identifier}") + raise HTTPException(status_code=423, detail="Account temporarily locked due to too many failed attempts") + + # Verify OTP with client IP tracking + logger.info(f"Verifying OTP for identifier: {identifier}") + otp_valid = await BookMyServiceOTPModel.verify_otp(identifier, otp, client_ip) + logger.info(f"OTP verification result: {otp_valid}") + + if not otp_valid: + logger.warning(f"Invalid or expired OTP for identifier: {identifier}") + # Track failed attempt + await BookMyServiceOTPModel.track_failed_attempt(identifier, client_ip) + raise HTTPException(status_code=400, detail="Invalid or expired OTP") + + # Clear failed attempts on successful verification + await BookMyServiceOTPModel.clear_failed_attempts(identifier) + logger.info(f"OTP verification successful for identifier: {identifier}") + + # Find user by identifier + logger.info(f"Looking up user by identifier: {identifier}") + user = await BookMyServiceUserModel.find_by_identifier(identifier) + logger.info(f"User lookup result: {user is not None}") + + if not user: + logger.warning(f"No user found for identifier: {identifier}") + raise HTTPException(status_code=404, detail="User not found") + + customer_id = user.get("customer_id") + logger.info(f"User found for identifier: {identifier}, customer_id: {customer_id}") + + # Create token family for refresh token rotation + family_id = await RefreshTokenModel.create_token_family(customer_id, device_info) + + # Create JWT access token + logger.info("Creating JWT token for authenticated user") + token_data = { + "sub": customer_id, + "exp": datetime.utcnow() + timedelta(minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES) + } + + access_token = jwt.encode(token_data, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM) + + # Create refresh token with rotation support + refresh_token, token_id, expires_at = create_refresh_token( + {"sub": customer_id}, + remember_me=remember_me, + family_id=family_id + ) + + # Store refresh token metadata + await RefreshTokenModel.store_refresh_token( + token_id=token_id, + customer_id=customer_id, + family_id=family_id, + expires_at=expires_at, + remember_me=remember_me, + device_info=device_info, + ip_address=client_ip + ) + + # Log generated tokens (truncated for security) + logger.info(f"Access token generated (first 25 chars): {access_token[:25]}...") + logger.info(f"Refresh token generated (first 25 chars): {refresh_token[:25]}...") + logger.info(f"JWT tokens created successfully for user: {customer_id}") + + return { + "access_token": access_token, + "refresh_token": refresh_token, + "token_type": "bearer", + "expires_in": settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60, + "customer_id": customer_id, + "name": user.get("name"), + "email": user.get("email"), + "profile_picture": user.get("profile_picture"), + "auth_method": user.get("auth_mode"), + "provider": None, + "security_info": None + } + + except ValueError as ve: + logger.error(f"Validation error for identifier {identifier}: {str(ve)}") + raise HTTPException(status_code=400, detail=str(ve)) + except HTTPException as e: + logger.error(f"HTTP error in otp_login_handler for {identifier}: {e.status_code} - {e.detail}") + raise e + except Exception as e: + logger.error(f"Unexpected error in otp_login_handler for {identifier}: {str(e)}", exc_info=True) + raise HTTPException(status_code=500, detail="Internal server error during login") + + @staticmethod + async def register(data: UserRegisterRequest, decoded): + logger.info(f"Registering user with data: {data}") + + # Validate mandatory fields for all registration modes + if not data.name or not data.name.strip(): + raise HTTPException(status_code=400, detail="Name is required") + + if not data.email: + raise HTTPException(status_code=400, detail="Email is required") + + if not data.phone or not data.phone.strip(): + raise HTTPException(status_code=400, detail="Phone is required") + + if data.mode == "otp": + # Always use phone as the OTP identifier as per documentation + identifier = data.phone + + # Validate phone format + try: + identifier_type = validate_identifier(identifier) + if identifier_type != "phone": + raise ValueError("Phone number format is invalid") + logger.info(f"Registration identifier type: {identifier_type}") + except ValueError as ve: + logger.error(f"Invalid phone format during registration: {str(ve)}") + raise HTTPException(status_code=400, detail=str(ve)) + + redis_key = f"bms_otp:{identifier}" + logger.info(f"Verifying OTP for Redis key: {redis_key}") + + if not data.otp: + raise HTTPException(status_code=400, detail="OTP is required") + + if not await BookMyServiceOTPModel.verify_otp(identifier, data.otp): + raise HTTPException(status_code=400, detail="Invalid or expired OTP") + + customer_id = str(uuid.uuid4()) + + elif data.mode == "oauth": + # Validate OAuth-specific mandatory fields + if not data.oauth_token or not data.provider: + raise HTTPException(status_code=400, detail="OAuth token and provider are required") + + # Extract user info from decoded token + user_info = decoded.get("user_info", {}) + provider_customer_id = user_info.get("sub") or user_info.get("id") + + if not provider_customer_id: + raise HTTPException(status_code=400, detail="Invalid OAuth user information") + + # Check if this social account already exists + existing_social_account = await SocialAccountModel.find_by_provider_and_customer_id( + data.provider, provider_customer_id + ) + + if existing_social_account: + # User already has this social account linked + existing_user = await BookMyServiceUserModel.collection.find_one({ + "customer_id": existing_social_account["customer_id"] + }) + if existing_user: + # Update social account with latest info and return existing user token + await SocialAccountModel.update_social_account(data.provider, provider_customer_id, user_info) + + token_data = { + "sub": existing_user["customer_id"], + "exp": datetime.utcnow() + timedelta(minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES) + } + access_token = jwt.encode(token_data, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM) + + # Create refresh token + refresh_token_data = { + "sub": existing_user["customer_id"], + "exp": datetime.utcnow() + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS) + } + refresh_token = jwt.encode(refresh_token_data, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM) + + # Log generated tokens for existing linked user (truncated) + logger.info(f"Access token for existing user (first 25 chars): {access_token[:25]}...") + logger.info(f"Refresh token for existing user (first 25 chars): {refresh_token[:25]}...") + + return { + "access_token": access_token, + "refresh_token": refresh_token, + "token_type": "bearer", + "expires_in": settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60 + } + + # Generate a new UUID for customer_id instead of provider-prefixed ID + customer_id = str(uuid.uuid4()) + + else: + raise HTTPException(status_code=400, detail="Unsupported registration mode") + + # Check if user already exists + if await BookMyServiceUserModel.collection.find_one({"customer_id": customer_id}): + raise HTTPException(status_code=409, detail="User already registered") + + # Check for existing email or phone + existing_user = await BookMyServiceUserModel.exists_by_email_or_phone( + email=data.email, + phone=data.phone + ) + if existing_user: + raise HTTPException(status_code=409, detail="User with this email or phone already exists") + + # Create user document + user_doc = { + "customer_id": customer_id, + "name": data.name, + "email": data.email, + "phone": data.phone, + "auth_mode": data.mode, + "created_at": datetime.utcnow() + } + + # Add profile picture from social account if available + if data.mode == "oauth" and user_info.get("picture"): + user_doc["profile_picture"] = user_info["picture"] + + await BookMyServiceUserModel.collection.insert_one(user_doc) + logger.info(f"Created new user: {customer_id}") + + # Create social account record for OAuth registration using UUID customer_id + if data.mode == "oauth": + await SocialAccountModel.create_social_account( + customer_id, data.provider, provider_customer_id, user_info + ) + logger.info(f"Created social account link for {data.provider} -> {customer_id}") + + # Create token family for refresh token rotation + family_id = await RefreshTokenModel.create_token_family(customer_id, data.device_info) + + token_data = { + "sub": customer_id, + "exp": datetime.utcnow() + timedelta(minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES) + } + + access_token = jwt.encode(token_data, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM) + + # Create refresh token with rotation support + refresh_token, token_id, expires_at = create_refresh_token( + {"sub": customer_id}, + remember_me=data.remember_me, + family_id=family_id + ) + + # Store refresh token metadata + await RefreshTokenModel.store_refresh_token( + token_id=token_id, + customer_id=customer_id, + family_id=family_id, + expires_at=expires_at, + remember_me=data.remember_me, + device_info=data.device_info, + ip_address=None # Can be passed from router if needed + ) + + # Log generated tokens for new registration (truncated) + logger.info(f"Access token on register (first 25 chars): {access_token[:25]}...") + logger.info(f"Refresh token on register (first 25 chars): {refresh_token[:25]}...") + + return { + "access_token": access_token, + "refresh_token": refresh_token, + "token_type": "bearer", + "expires_in": settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60, + "customer_id": customer_id, + "name": data.name, + "email": data.email, + "profile_picture": user_doc.get("profile_picture"), + "auth_method": data.mode, + "provider": data.provider if data.mode == "oauth" else None, + "security_info": None + } \ No newline at end of file diff --git a/app/services/wallet_service.py b/app/services/wallet_service.py new file mode 100644 index 0000000000000000000000000000000000000000..11c3da1fa541fdc4df57824cb14eb1faa865eb40 --- /dev/null +++ b/app/services/wallet_service.py @@ -0,0 +1,212 @@ +from typing import Dict, Any, Optional +import logging +from datetime import datetime + +from app.models.wallet_model import WalletModel +from app.schemas.wallet_schema import ( + WalletBalanceResponse, WalletSummaryResponse, TransactionHistoryResponse, + TransactionEntry, WalletTransactionResponse +) + +logger = logging.getLogger(__name__) + +class WalletService: + """Service for wallet operations""" + + @staticmethod + async def get_wallet_balance(customer_id: str) -> WalletBalanceResponse: + """Get formatted wallet balance for user""" + try: + balance = await WalletModel.get_wallet_balance(customer_id) + + return WalletBalanceResponse( + balance=balance, + currency="INR", + formatted_balance=f"₹{balance:,.2f}" + ) + + except Exception as e: + logger.error(f"Error getting wallet balance for user {customer_id}: {str(e)}") + return WalletBalanceResponse( + balance=0.0, + currency="INR", + formatted_balance="₹0.00" + ) + + @staticmethod + async def get_wallet_summary(customer_id: str) -> WalletSummaryResponse: + """Get wallet summary with balance and recent transactions""" + try: + summary_data = await WalletModel.get_wallet_summary(customer_id) + + # Convert transactions to schema format + recent_transactions = [] + for transaction in summary_data.get("recent_transactions", []): + recent_transactions.append(TransactionEntry( + transaction_id=transaction["_id"], + amount=transaction["amount"], + transaction_type=transaction["transaction_type"], + description=transaction["description"], + reference_id=transaction.get("reference_id"), + balance_before=transaction["balance_before"], + balance_after=transaction["balance_after"], + timestamp=transaction["timestamp"], + status=transaction["status"] + )) + + balance = summary_data.get("balance", 0.0) + + return WalletSummaryResponse( + balance=balance, + formatted_balance=f"₹{balance:,.2f}", + recent_transactions=recent_transactions + ) + + except Exception as e: + logger.error(f"Error getting wallet summary for user {customer_id}: {str(e)}") + return WalletSummaryResponse( + balance=0.0, + formatted_balance="₹0.00", + recent_transactions=[] + ) + + @staticmethod + async def get_transaction_history(customer_id: str, page: int = 1, per_page: int = 20) -> TransactionHistoryResponse: + """Get paginated transaction history""" + try: + history_data = await WalletModel.get_transaction_history(customer_id, page, per_page) + + # Convert transactions to schema format + transactions = [] + for transaction in history_data.get("transactions", []): + transactions.append(TransactionEntry( + transaction_id=transaction["_id"], + amount=transaction["amount"], + transaction_type=transaction["transaction_type"], + description=transaction["description"], + reference_id=transaction.get("reference_id"), + balance_before=transaction["balance_before"], + balance_after=transaction["balance_after"], + timestamp=transaction["timestamp"], + status=transaction["status"] + )) + + return TransactionHistoryResponse( + transactions=transactions, + total_count=history_data.get("total_count", 0), + page=page, + per_page=per_page, + total_pages=history_data.get("total_pages", 0) + ) + + except Exception as e: + logger.error(f"Error getting transaction history for user {customer_id}: {str(e)}") + return TransactionHistoryResponse( + transactions=[], + total_count=0, + page=page, + per_page=per_page, + total_pages=0 + ) + + @staticmethod + async def add_money(customer_id: str, amount: float, payment_method: str, + description: str = "Wallet top-up", reference_id: str = None) -> WalletTransactionResponse: + """Add money to wallet""" + try: + success = await WalletModel.update_balance( + customer_id=customer_id, + amount=amount, + transaction_type="credit", + description=f"{description} via {payment_method}", + reference_id=reference_id + ) + + if success: + new_balance = await WalletModel.get_wallet_balance(customer_id) + return WalletTransactionResponse( + success=True, + message=f"Successfully added ₹{amount:,.2f} to wallet", + transaction_id=reference_id, + new_balance=new_balance + ) + else: + return WalletTransactionResponse( + success=False, + message="Failed to add money to wallet" + ) + + except Exception as e: + logger.error(f"Error adding money to wallet for user {customer_id}: {str(e)}") + return WalletTransactionResponse( + success=False, + message="Internal error occurred while adding money" + ) + + @staticmethod + async def deduct_money(customer_id: str, amount: float, description: str, + reference_id: str = None) -> WalletTransactionResponse: + """Deduct money from wallet (for payments)""" + try: + success = await WalletModel.update_balance( + customer_id=customer_id, + amount=amount, + transaction_type="debit", + description=description, + reference_id=reference_id + ) + + if success: + new_balance = await WalletModel.get_wallet_balance(customer_id) + return WalletTransactionResponse( + success=True, + message=f"Successfully deducted ₹{amount:,.2f} from wallet", + transaction_id=reference_id, + new_balance=new_balance + ) + else: + return WalletTransactionResponse( + success=False, + message="Insufficient balance or transaction failed" + ) + + except Exception as e: + logger.error(f"Error deducting money from wallet for user {customer_id}: {str(e)}") + return WalletTransactionResponse( + success=False, + message="Internal error occurred while processing payment" + ) + + @staticmethod + async def process_refund(customer_id: str, amount: float, description: str, + reference_id: str = None) -> WalletTransactionResponse: + """Process refund to wallet""" + try: + success = await WalletModel.update_balance( + customer_id=customer_id, + amount=amount, + transaction_type="refund", + description=description, + reference_id=reference_id + ) + + if success: + new_balance = await WalletModel.get_wallet_balance(customer_id) + return WalletTransactionResponse( + success=True, + message=f"Refund of ₹{amount:,.2f} processed successfully", + transaction_id=reference_id, + new_balance=new_balance + ) + else: + return WalletTransactionResponse( + success=False, + message="Failed to process refund" + ) + + except Exception as e: + logger.error(f"Error processing refund for user {customer_id}: {str(e)}") + return WalletTransactionResponse( + success=False, + message="Internal error occurred while processing refund" + ) \ No newline at end of file diff --git a/app/utils/__init__.py b/app/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/app/utils/common_utils.py b/app/utils/common_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..aade825376db8956e177c9c9c9c8b150187ff869 --- /dev/null +++ b/app/utils/common_utils.py @@ -0,0 +1,31 @@ +import re + +def is_email(identifier: str) -> bool: + return re.match(r"[^@]+@[^@]+\.[^@]+", identifier) is not None + +def is_phone(identifier: str) -> bool: + """ + Validate phone number format. Supports: + - International format: +1234567890, +91-9876543210 + - National format: 9876543210, (123) 456-7890 + - With/without spaces, dashes, parentheses + """ + # Remove all non-digit characters except + + cleaned = re.sub(r'[^\d+]', '', identifier) + + # Check if it's a valid phone number (8-15 digits, optionally starting with +) + if re.match(r'^\+?[1-9]\d{7,14}$', cleaned): + return True + return False + +def validate_identifier(identifier: str) -> str: + """ + Validate and return the type of identifier (email or phone). + Raises ValueError if neither email nor phone format. + """ + if is_email(identifier): + return "email" + elif is_phone(identifier): + return "phone" + else: + raise ValueError("Identifier must be a valid email address or phone number") \ No newline at end of file diff --git a/app/utils/db.py b/app/utils/db.py new file mode 100644 index 0000000000000000000000000000000000000000..2e65291b56af14f47d73e7cc4dabb06bef3ed0bb --- /dev/null +++ b/app/utils/db.py @@ -0,0 +1,26 @@ +from datetime import datetime,date +from decimal import Decimal +from typing import Any +from pydantic import BaseModel + +def prepare_for_db(obj: Any) -> Any: + """ + Recursively sanitizes the object to be MongoDB-compatible: + - Converts Decimal to float + - Converts datetime with tzinfo to naive datetime + - Converts Pydantic BaseModel to dict + """ + if isinstance(obj, Decimal): + return float(obj) + elif isinstance(obj, date) and not isinstance(obj, datetime): + return datetime(obj.year, obj.month, obj.day) + elif isinstance(obj, datetime): + return obj.replace(tzinfo=None) + elif isinstance(obj, BaseModel): + return prepare_for_db(obj.dict()) + elif isinstance(obj, dict): + return {k: prepare_for_db(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [prepare_for_db(v) for v in obj] + else: + return obj \ No newline at end of file diff --git a/app/utils/email_utils.py b/app/utils/email_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..183017dfc2beba2f0447a4eb27df04f4ed9f440f --- /dev/null +++ b/app/utils/email_utils.py @@ -0,0 +1,16 @@ +import smtplib +from email.mime.text import MIMEText +from app.core.config import settings + +async def send_email_otp(to_email: str, otp: str, timeout: float = 10.0): + msg = MIMEText(f"Your OTP is {otp}. It is valid for 5 minutes.") + msg["Subject"] = "Your One-Time Password" + msg["From"] = settings.SMTP_FROM + msg["To"] = to_email + + server = smtplib.SMTP(settings.SMTP_HOST, settings.SMTP_PORT, timeout=timeout) + server.connect(settings.SMTP_HOST, settings.SMTP_PORT) + server.starttls() + server.login(settings.SMTP_USER, settings.SMTP_PASS) + server.send_message(msg) + server.quit() \ No newline at end of file diff --git a/app/utils/jwt.py b/app/utils/jwt.py new file mode 100644 index 0000000000000000000000000000000000000000..a6c81418f057574477f5da0da09a7795cbe1be31 --- /dev/null +++ b/app/utils/jwt.py @@ -0,0 +1,137 @@ + +## bookmyservice-ums/app/utils/jwt.py + +from jose import jwt, JWTError +from datetime import datetime, timedelta +from fastapi import Depends, HTTPException, status +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +from typing import Optional +from app.core.config import settings +import logging +import uuid + +SECRET_KEY = settings.JWT_SECRET_KEY +ALGORITHM = settings.JWT_ALGORITHM +ACCESS_EXPIRE_MINUTES_DEFAULT = settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES +REFRESH_EXPIRE_DAYS_DEFAULT = settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS +TEMP_EXPIRE_MINUTES_DEFAULT = settings.JWT_TEMP_TOKEN_EXPIRE_MINUTES + +# Remember me settings +REMEMBER_ME_REFRESH_EXPIRE_DAYS = settings.JWT_REMEMBER_ME_EXPIRE_DAYS + +# Security scheme +security = HTTPBearer() + +# Module logger (app-level logging config applies) +logger = logging.getLogger(__name__) + +def create_access_token(data: dict, expires_minutes: int = ACCESS_EXPIRE_MINUTES_DEFAULT): + to_encode = data.copy() + expire = datetime.utcnow() + timedelta(minutes=expires_minutes) + to_encode.update({"exp": expire}) + + # Avoid logging sensitive payload; log minimal context + logger.info( + "Creating access token", + ) + logger.info( + "Access token claims keys=%s expires_at=%s", + list(to_encode.keys()), + expire.isoformat(), + ) + return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) + +def create_refresh_token( + data: dict, + expires_days: int = REFRESH_EXPIRE_DAYS_DEFAULT, + remember_me: bool = False, + family_id: Optional[str] = None +): + """Create refresh token with rotation support""" + to_encode = data.copy() + + # Use longer expiry for remember me + if remember_me: + expires_days = REMEMBER_ME_REFRESH_EXPIRE_DAYS + + expire = datetime.utcnow() + timedelta(days=expires_days) + + # Generate unique token ID for tracking + token_id = str(uuid.uuid4()) + + to_encode.update({ + "exp": expire, + "type": "refresh", + "jti": token_id, # JWT ID for token tracking + "remember_me": remember_me + }) + + # Add family ID for rotation tracking + if family_id: + to_encode["family_id"] = family_id + + logger.info("Creating refresh token") + logger.info( + "Refresh token claims keys=%s expires_at=%s remember_me=%s", + list(to_encode.keys()), + expire.isoformat(), + remember_me + ) + + return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM), token_id, expire + +def create_temp_token(data: dict, expires_minutes: int = TEMP_EXPIRE_MINUTES_DEFAULT): + logger.info("Creating temporary access token with short expiry") + return create_access_token(data, expires_minutes=expires_minutes) + +def decode_token(token: str) -> dict: + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + logger.info("Token decoded successfully") + logger.info("Decoded claims keys=%s", list(payload.keys())) + return payload + except JWTError as e: + logger.warning("Token decode failed: %s", str(e)) + return {} + +def verify_token(token: str) -> dict: + """ + Verify and decode JWT token, raise HTTPException if invalid. + """ + credentials_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + customer_id: str = payload.get("sub") + if customer_id is None: + logger.warning("Verified token missing 'sub' claim") + raise credentials_exception + logger.info("Token verified for subject") + logger.info("Verified claims keys=%s", list(payload.keys())) + return payload + except JWTError as e: + logger.error("Token verification failed: %s", str(e)) + raise credentials_exception + +async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)) -> dict: + """ + Dependency to get current authenticated user from JWT token. + """ + token = credentials.credentials + logger.info("Authenticating request with Bearer token") + # Don't log raw tokens; log minimal metadata + logger.info("Bearer token length=%d", len(token) if token else 0) + return verify_token(token) + +async def get_current_customer_id(current_user: dict = Depends(get_current_user)) -> str: + """ + Dependency to get current user ID. + """ + customer_id = current_user.get("sub") + logger.info("Resolved current customer id") + logger.info("Current customer id=%s", customer_id) + return customer_id \ No newline at end of file diff --git a/app/utils/logger.py b/app/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..55ea5360f8ecb2b72eab469086fc1f54041fd322 --- /dev/null +++ b/app/utils/logger.py @@ -0,0 +1,32 @@ +import logging +import sys + +def setup_logger(name): + """ + Set up a logger with consistent formatting and settings + + Args: + name (str): The name for the logger, typically __name__ + + Returns: + logging.Logger: Configured logger instance + """ + logger = logging.getLogger(name) + + # Only configure handlers if they don't exist + if not logger.handlers: + # Console handler + console_handler = logging.StreamHandler(sys.stdout) + console_format = logging.Formatter( + '%(asctime)s - %(name)s - %(levelname)s - %(message)s' + ) + console_handler.setFormatter(console_format) + logger.addHandler(console_handler) + + # Set level - could be read from environment variables + logger.setLevel(logging.INFO) + + # Prevent propagation to root logger to avoid duplicate logs + logger.propagate = False + + return logger diff --git a/app/utils/sms_utils.py b/app/utils/sms_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..124fbb67899af40dfe1bfd92168afc685317b6f2 --- /dev/null +++ b/app/utils/sms_utils.py @@ -0,0 +1,15 @@ +from twilio.rest import Client +from twilio.http.http_client import TwilioHttpClient +from app.core.config import settings + +def send_sms_otp(phone: str, otp: str) -> str: + http_client = TwilioHttpClient(timeout=10) # 10 seconds timeout + client = Client(settings.TWILIO_ACCOUNT_SID, settings.TWILIO_AUTH_TOKEN, http_client=http_client) + + message = client.messages.create( + from_=settings.TWILIO_SMS_FROM, + body=f"Your OTP is {otp}", + to=phone + ) + + return message.sid \ No newline at end of file diff --git a/app/utils/social_utils.py b/app/utils/social_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..be5750c1b0ea4a67a6c022c41c031708cd7754a1 --- /dev/null +++ b/app/utils/social_utils.py @@ -0,0 +1,372 @@ +from google.oauth2 import id_token as google_id_token +from google.auth.transport import requests as google_requests +from jose import jwt as jose_jwt, JWTError, jwk +from jose.utils import base64url_decode +from typing import Dict, Optional +import httpx +import asyncio +from functools import partial, lru_cache +import logging +from datetime import datetime, timedelta + +logger = logging.getLogger(__name__) + +class TokenVerificationError(Exception): + """Custom exception for token verification errors""" + pass + +class FacebookTokenVerifier: + def __init__(self, app_id: str, app_secret: str): + self.app_id = app_id + self.app_secret = app_secret + + async def verify_token(self, token: str) -> Dict: + """ + Asynchronously verifies a Facebook access token and returns user data. + """ + try: + # First, verify the token with Facebook's debug endpoint + async with httpx.AsyncClient(timeout=10.0) as client: + # Verify token validity + debug_url = f"https://graph.facebook.com/debug_token" + debug_params = { + "input_token": token, + "access_token": f"{self.app_id}|{self.app_secret}" + } + + debug_response = await client.get(debug_url, params=debug_params) + debug_response.raise_for_status() + debug_data = debug_response.json() + + if not debug_data.get("data", {}).get("is_valid"): + raise TokenVerificationError("Invalid Facebook token") + + # Check if token is for our app + token_app_id = debug_data.get("data", {}).get("app_id") + if token_app_id != self.app_id: + raise TokenVerificationError("Token not for this app") + + # Get user data + user_url = "https://graph.facebook.com/me" + user_params = { + "access_token": token, + "fields": "id,name,email,picture.type(large)" + } + + user_response = await client.get(user_url, params=user_params) + user_response.raise_for_status() + user_data = user_response.json() + + # Validate required fields + if not user_data.get("id"): + raise TokenVerificationError("Missing user ID in Facebook response") + + logger.info(f"Successfully verified Facebook token for user: {user_data.get('email', user_data.get('id'))}") + return user_data + + except httpx.RequestError as e: + logger.error(f"Facebook token verification request failed: {str(e)}") + raise TokenVerificationError(f"Facebook API request failed: {str(e)}") + except Exception as e: + logger.error(f"Facebook token verification failed: {str(e)}") + raise TokenVerificationError(f"Invalid Facebook token: {str(e)}") + +class GoogleTokenVerifier: + def __init__(self, client_id: str): + self.client_id = client_id + + async def verify_token(self, token: str) -> Dict: + """ + Asynchronously verifies a Google ID token and returns the payload if valid. + """ + try: + loop = asyncio.get_event_loop() + # Run the sync method in a thread to avoid blocking + idinfo = await loop.run_in_executor( + None, + partial(google_id_token.verify_oauth2_token, token, google_requests.Request(), self.client_id) + ) + + # Validate issuer + if idinfo.get('iss') not in ['accounts.google.com', 'https://accounts.google.com']: + raise TokenVerificationError('Invalid issuer') + + # Additional validation + if idinfo.get('aud') != self.client_id: + raise TokenVerificationError('Invalid audience') + + # Check token expiration (extra safety) + exp = idinfo.get('exp') + if exp and datetime.fromtimestamp(exp) < datetime.utcnow(): + raise TokenVerificationError('Token has expired') + + logger.info(f"Successfully verified Google token for user: {idinfo.get('email')}") + return idinfo + + except Exception as e: + logger.error(f"Google token verification failed: {str(e)}") + raise TokenVerificationError(f"Invalid Google token: {str(e)}") + +class GoogleAccessTokenVerifier: + def __init__(self): + pass + + async def verify_access_token(self, access_token: str) -> Dict: + """ + Verify Google OAuth access token by calling the OIDC UserInfo endpoint. + Returns normalized user info containing at least 'sub' and 'email' if available. + """ + try: + headers = {"Authorization": f"Bearer {access_token}"} + async with httpx.AsyncClient(timeout=10.0) as client: + # Google OIDC UserInfo endpoint + resp = await client.get("https://openidconnect.googleapis.com/v1/userinfo", headers=headers) + resp.raise_for_status() + data = resp.json() + + # Basic sanity checks + if not data.get("sub") and not data.get("id"): + raise TokenVerificationError("Missing subject/id in Google userinfo response") + + # Normalize shape similar to ID token claims + user_info = { + "sub": data.get("sub") or data.get("id"), + "email": data.get("email"), + "email_verified": data.get("email_verified"), + "name": data.get("name"), + "picture": data.get("picture"), + } + logger.info(f"Successfully verified Google access token for user: {user_info.get('email')}") + return user_info + except httpx.HTTPStatusError as e: + logger.error(f"Google userinfo verification failed: {e.response.text}") + raise TokenVerificationError("Invalid Google access token") + except httpx.RequestError as e: + logger.error(f"Google userinfo request error: {str(e)}") + raise TokenVerificationError(f"Google userinfo request failed: {str(e)}") + except Exception as e: + logger.error(f"Google access token verification failed: {str(e)}") + raise TokenVerificationError(f"Invalid Google access token: {str(e)}") + +class AppleTokenVerifier: + def __init__(self, audience: str, cache_duration: int = 3600): + self.audience = audience + self.cache_duration = cache_duration + self._keys_cache = None + self._cache_timestamp = None + + async def _get_apple_keys(self) -> list: + """ + Fetch Apple's public keys with caching to reduce API calls. + """ + now = datetime.utcnow() + + # Check if cache is still valid + if (self._keys_cache and self._cache_timestamp and + now - self._cache_timestamp < timedelta(seconds=self.cache_duration)): + return self._keys_cache + + try: + async with httpx.AsyncClient(timeout=10.0) as client: + response = await client.get('https://appleid.apple.com/auth/keys') + response.raise_for_status() + keys = response.json().get('keys', []) + + # Update cache + self._keys_cache = keys + self._cache_timestamp = now + + logger.info("Successfully fetched Apple public keys") + return keys + + except httpx.RequestError as e: + logger.error(f"Failed to fetch Apple keys: {str(e)}") + raise TokenVerificationError(f"Could not fetch Apple public keys: {str(e)}") + + async def verify_token(self, token: str) -> Dict: + """ + Asynchronously verifies an Apple identity token and returns the decoded payload. + """ + try: + # Fetch Apple's public keys + keys = await self._get_apple_keys() + + # Decode header to get kid and alg + header = jose_jwt.get_unverified_header(token) + kid = header.get('kid') + alg = header.get('alg') + + if not kid or not alg: + raise TokenVerificationError("Token header missing required fields") + + # Find matching key + key = next((k for k in keys if k['kid'] == kid and k['alg'] == alg), None) + if not key: + raise TokenVerificationError("Public key not found for Apple token") + + # Verify signature manually (additional safety) + public_key = jwk.construct(key) + message, encoded_sig = token.rsplit('.', 1) + decoded_sig = base64url_decode(encoded_sig.encode()) + + if not public_key.verify(message.encode(), decoded_sig): + raise TokenVerificationError("Invalid Apple token signature") + + # Decode and validate claims + claims = jose_jwt.decode( + token, + key, + algorithms=['RS256'], + audience=self.audience, + issuer='https://appleid.apple.com' + ) + + # Additional validation + if claims.get('aud') != self.audience: + raise TokenVerificationError('Invalid audience') + + logger.info(f"Successfully verified Apple token for user: {claims.get('sub')}") + return claims + + except JWTError as e: + logger.error(f"JWT error during Apple token verification: {str(e)}") + raise TokenVerificationError(f"Invalid Apple token: {str(e)}") + except Exception as e: + logger.error(f"Apple token verification failed: {str(e)}") + raise TokenVerificationError(f"Invalid Apple token: {str(e)}") + +# Factory class for easier usage +class OAuthVerifier: + def __init__(self, google_client_id: Optional[str] = None, apple_audience: Optional[str] = None, + facebook_app_id: Optional[str] = None, facebook_app_secret: Optional[str] = None): + self.google_verifier = GoogleTokenVerifier(google_client_id) if google_client_id else None + self.google_access_verifier = GoogleAccessTokenVerifier() + self.apple_verifier = AppleTokenVerifier(apple_audience) if apple_audience else None + self.facebook_verifier = FacebookTokenVerifier(facebook_app_id, facebook_app_secret) if facebook_app_id and facebook_app_secret else None + + async def verify_google_token(self, token: str) -> Dict: + if not self.google_verifier: + raise TokenVerificationError("Google verifier not configured") + return await self.google_verifier.verify_token(token) + + async def verify_google_access_token(self, token: str) -> Dict: + return await self.google_access_verifier.verify_access_token(token) + + async def verify_apple_token(self, token: str) -> Dict: + if not self.apple_verifier: + raise TokenVerificationError("Apple verifier not configured") + return await self.apple_verifier.verify_token(token) + + async def verify_facebook_token(self, token: str) -> Dict: + if not self.facebook_verifier: + raise TokenVerificationError("Facebook verifier not configured") + return await self.facebook_verifier.verify_token(token) + +# Convenience functions (backward compatibility) +async def verify_google_token(token: str, client_id: str) -> Dict: + """ + Asynchronously verifies a Google ID token and returns the payload if valid. + In local test mode, bypass external verification and synthesize minimal claims. + """ + try: + from app.core.config import settings + if getattr(settings, "OAUTH_TEST_MODE", False): + # Accept anything shaped like a JWT, otherwise synthesize from raw token + sub = "test-google-sub" + email = "test.user@example.com" + if token.count(".") == 2: + # Try to decode header/payload without verification to extract email/sub if present + try: + unverified = jose_jwt.get_unverified_claims(token) + sub = unverified.get("sub", sub) + email = unverified.get("email", email) + except Exception: + pass + return {"sub": sub, "email": email, "aud": client_id} + except Exception: + # Fall through to real verification if settings import fails + pass + verifier = GoogleTokenVerifier(client_id) + return await verifier.verify_token(token) + +async def verify_google_access_token(token: str) -> Dict: + """ + Asynchronously verifies a Google OAuth access token via the UserInfo endpoint. + In local test mode, bypass network call and return synthetic user info. + """ + try: + from app.core.config import settings + if getattr(settings, "OAUTH_TEST_MODE", False): + # Strip optional Bearer prefix + t = token + if t.lower().startswith("bearer "): + t = t[7:] + return { + "sub": "test-google-access", + "email": "test.access@example.com", + "email_verified": True, + "name": "Test Access", + "picture": None, + } + except Exception: + pass + verifier = GoogleAccessTokenVerifier() + return await verifier.verify_access_token(token) + +async def verify_apple_token(token: str, audience: str) -> Dict: + """ + Asynchronously verifies an Apple identity token and returns the decoded payload. + In local test mode, bypass external verification and synthesize minimal claims. + """ + try: + from app.core.config import settings + if getattr(settings, "OAUTH_TEST_MODE", False): + sub = "test-apple-sub" + if token.count(".") == 2: + try: + unverified = jose_jwt.get_unverified_claims(token) + sub = unverified.get("sub", sub) + except Exception: + pass + return {"sub": sub, "aud": audience} + except Exception: + pass + verifier = AppleTokenVerifier(audience) + return await verifier.verify_token(token) + +async def verify_facebook_token(token: str, app_id: str, app_secret: str) -> Dict: + """ + Asynchronously verifies a Facebook access token and returns user data. + In local test mode, bypass network call and return synthetic user data. + """ + try: + from app.core.config import settings + if getattr(settings, "OAUTH_TEST_MODE", False): + return {"id": "test-facebook-id", "name": "Test FB", "email": "fb.test@example.com"} + except Exception: + pass + verifier = FacebookTokenVerifier(app_id, app_secret) + return await verifier.verify_token(token) + +# Example usage +async def example_usage(): + # Initialize verifier + oauth_verifier = OAuthVerifier( + google_client_id="your-google-client-id.googleusercontent.com", + apple_audience="your.app.bundle.id" + ) + + try: + # Verify Google token + google_claims = await oauth_verifier.verify_google_token("google_id_token_here") + print(f"Google user: {google_claims.get('email')}") + + # Verify Apple token + apple_claims = await oauth_verifier.verify_apple_token("apple_id_token_here") + print(f"Apple user: {apple_claims.get('sub')}") + + except TokenVerificationError as e: + print(f"Verification failed: {e}") + +if __name__ == "__main__": + asyncio.run(example_usage()) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 83050042ad1f6d089d33c29844d3996568b1731d..50f8e76d2dbcff78ca33f2205e91352551f565e8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,4 +19,6 @@ pandas>=1.5.0 bleach>=6.0.0 psutil>=5.9.0 spacy>=3.5.0 +twilio +google-auth