"""
Database Service for PsyAdGenesis
Handles MongoDB connection and CRUD operations
"""

from motor.motor_asyncio import AsyncIOMotorClient
from typing import Optional, Dict, Any, List
from datetime import datetime, timezone
from bson import ObjectId
from bson.errors import InvalidId
import json
import sys
import os

# Make the project root importable so `config` resolves when this module is
# loaded from a subdirectory.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from config import settings


class DatabaseService:
    """Async MongoDB database service for storing ad creatives.

    All public methods degrade gracefully when the database is not connected:
    they return None / False / empty results (or {"connected": False} for
    stats) instead of raising, and failures are reported with print().
    """

    def __init__(self):
        self.client: Optional[AsyncIOMotorClient] = None
        self.db = None
        self.collection = None  # the "ad_creatives" collection once connected
        self.mongodb_url = settings.mongodb_url
        self.db_name = settings.mongodb_db_name

    @staticmethod
    def _datetime_to_iso_utc(dt: datetime) -> str:
        """
        Convert a datetime to an ISO-8601 string that always ends in 'Z' (UTC).

        Naive datetimes are treated as already being UTC. The client is created
        without ``tz_aware``, so the driver returns naive UTC datetimes by
        default. Aware datetimes are converted to UTC first, so the result never
        carries an explicit offset such as '+00:00' or '-05:00'.
        """
        if dt.tzinfo is not None:
            dt = dt.astimezone(timezone.utc).replace(tzinfo=None)
        return dt.isoformat() + "Z"

    @staticmethod
    def _to_object_id(ad_id: str) -> Any:
        """
        Return ``ObjectId(ad_id)`` when *ad_id* is a valid ObjectId; otherwise
        return *ad_id* unchanged so documents whose ``_id`` is a plain string
        can still be matched.
        """
        try:
            return ObjectId(ad_id)
        except (InvalidId, TypeError):
            # Not an ObjectId representation - fall back to matching the raw value.
            return ad_id

    def _serialize_document(self, doc: Dict[str, Any]) -> Dict[str, Any]:
        """
        Serialize a MongoDB document for JSON response (mutates *doc* in place):
        - Convert ObjectId ``_id`` to string ``id``
        - Convert ``created_at`` / ``updated_at`` datetimes to ISO strings with 'Z'
        """
        if not doc:
            return doc

        # Convert ObjectId to string
        if "_id" in doc:
            doc["id"] = str(doc["_id"])
            del doc["_id"]

        # Convert datetime fields to ISO strings with UTC indicator
        datetime_fields = ["created_at", "updated_at"]
        for field in datetime_fields:
            if field in doc and isinstance(doc[field], datetime):
                doc[field] = self._datetime_to_iso_utc(doc[field])

        return doc

    async def connect(self):
        """
        Open the MongoDB connection and ensure indexes on ``ad_creatives``.

        The new client is only stored on the instance once the ping and every
        index creation have succeeded. On any failure the half-opened client is
        closed and the service keeps its previous (normally disconnected) state,
        so the False return value matches what the other methods observe.

        Returns:
            True on success; False if MONGODB_URL is unset or connecting failed.
        """
        if not self.mongodb_url:
            print("Warning: MONGODB_URL not configured. Database features disabled.")
            return False

        client = None
        try:
            client = AsyncIOMotorClient(self.mongodb_url)
            # Test connection
            await client.admin.command('ping')
            db = client[self.db_name]
            collection = db["ad_creatives"]

            # Create indexes for better query performance
            await collection.create_index("niche")
            await collection.create_index("created_at")
            await collection.create_index([("niche", 1), ("created_at", -1)])
            await collection.create_index("generation_method")
            await collection.create_index("username")  # User-specific index
            await collection.create_index([("username", 1), ("created_at", -1)])  # User + date index
        except Exception as e:
            print(f"✗ Database connection failed: {e}")
            if client is not None:
                client.close()  # don't leak sockets/monitor threads on failure
            return False

        self.client = client
        self.db = db
        self.collection = collection
        print(f"✓ Connected to MongoDB database: {self.db_name}")
        return True

    async def disconnect(self):
        """Close the database connection and return to the disconnected state."""
        if self.client is not None:
            self.client.close()
            # Reset handles so later calls take the "not connected" path
            # instead of operating on a closed client.
            self.client = None
            self.db = None
            self.collection = None
            print("Database connection closed")

    async def save_ad_creative(
        self,
        niche: str,
        title: str,
        headline: str,
        primary_text: str,
        description: str,
        body_story: str,
        cta: str,
        psychological_angle: str,
        why_it_works: str,
        username: str,  # Required: username of the user creating the ad
        image_url: Optional[str] = None,
        r2_url: Optional[str] = None,
        image_filename: Optional[str] = None,
        image_model: Optional[str] = None,
        image_seed: Optional[int] = None,
        image_prompt: Optional[str] = None,
        angle_key: Optional[str] = None,
        angle_name: Optional[str] = None,
        angle_trigger: Optional[str] = None,
        angle_category: Optional[str] = None,
        concept_key: Optional[str] = None,
        concept_name: Optional[str] = None,
        concept_structure: Optional[str] = None,
        concept_visual: Optional[str] = None,
        concept_category: Optional[str] = None,
        generation_method: str = "standard",
        metadata: Optional[Dict[str, Any]] = None,
    ) -> Optional[str]:
        """
        Save an ad creative to the database.

        Fields whose value is None are omitted from the stored document.
        ``created_at`` is set to the current UTC time.

        Returns:
            The ID of the saved record, or None if save failed.
        """
        if self.collection is None:
            print("Database not connected - skipping save")
            return None

        try:
            # Prepare document
            doc = {
                "niche": niche,
                "title": title,
                "headline": headline,
                "primary_text": primary_text,
                "description": description,
                "body_story": body_story,
                "cta": cta,
                "psychological_angle": psychological_angle,
                "why_it_works": why_it_works,
                "username": username,  # Store username for filtering
                "image_url": image_url,
                "r2_url": r2_url,
                "image_filename": image_filename,
                "image_model": image_model,
                "image_seed": image_seed,
                "image_prompt": image_prompt,
                "angle_key": angle_key,
                "angle_name": angle_name,
                "angle_trigger": angle_trigger,
                "angle_category": angle_category,
                "concept_key": concept_key,
                "concept_name": concept_name,
                "concept_structure": concept_structure,
                "concept_visual": concept_visual,
                "concept_category": concept_category,
                "generation_method": generation_method,
                "metadata": metadata,
                "created_at": datetime.now(timezone.utc),
            }

            # Remove None values to keep documents clean
            doc = {k: v for k, v in doc.items() if v is not None}

            result = await self.collection.insert_one(doc)
            return str(result.inserted_id)
        except Exception as e:
            print(f"Failed to save ad creative: {e}")
            return None

    async def get_ad_creative(self, ad_id: str, username: Optional[str] = None) -> Optional[Dict[str, Any]]:
        """
        Get a single ad creative by ID.

        If username is provided, only returns the ad if it belongs to that user.
        Returns the serialized document, or None if not found / on error.
        """
        if self.collection is None:
            return None

        try:
            query = {"_id": self._to_object_id(ad_id)}
            if username:
                query["username"] = username

            doc = await self.collection.find_one(query)
            if doc:
                return self._serialize_document(doc)
            return None
        except Exception as e:
            print(f"Failed to get ad creative: {e}")
            return None

    async def list_ad_creatives(
        self,
        username: str,  # Required: filter by username
        niche: Optional[str] = None,
        generation_method: Optional[str] = None,
        limit: int = 50,
        offset: int = 0,
    ) -> tuple[List[Dict[str, Any]], int]:
        """
        List ad creatives for a specific user with optional filtering.

        Results are ordered newest first (``created_at`` descending).
        Returns (results, total_count), where total_count ignores pagination.
        """
        if self.collection is None:
            return [], 0

        try:
            query = {"username": username}  # Always filter by username
            if niche:
                query["niche"] = niche
            if generation_method:
                query["generation_method"] = generation_method

            # Get total count before pagination
            total_count = await self.collection.count_documents(query)

            cursor = self.collection.find(query).sort("created_at", -1).skip(offset).limit(limit)
            docs = await cursor.to_list(length=limit)

            # Convert documents to dict format
            results = [self._serialize_document(doc) for doc in docs]
            return results, total_count
        except Exception as e:
            print(f"Failed to list ad creatives: {e}")
            return [], 0

    async def update_ad_creative(
        self,
        ad_id: str,
        username: Optional[str] = None,
        image_url: Optional[str] = None,
        image_filename: Optional[str] = None,
        image_model: Optional[str] = None,
        image_prompt: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs
    ) -> bool:
        """
        Update an ad creative by ID.

        If username is provided, only updates if the ad belongs to that user.
        Arguments that are None are left untouched; ``updated_at`` is always set.

        Args:
            ad_id: ID of the ad to update
            username: Optional username to verify ownership
            image_url: New image URL
            image_filename: New image filename
            image_model: New image model
            image_prompt: New image prompt
            metadata: Metadata dict (will be merged with existing metadata)
            **kwargs: Additional fields to update (None values are skipped)

        Returns:
            True if update was successful, False otherwise
        """
        if self.collection is None:
            return False

        try:
            query = {"_id": self._to_object_id(ad_id)}
            if username:
                query["username"] = username

            # Build update document
            update_doc = {"updated_at": datetime.now(timezone.utc)}

            if image_url is not None:
                update_doc["image_url"] = image_url
            if image_filename is not None:
                update_doc["image_filename"] = image_filename
            if image_model is not None:
                update_doc["image_model"] = image_model
            if image_prompt is not None:
                update_doc["image_prompt"] = image_prompt

            # Add any additional fields from kwargs
            for key, value in kwargs.items():
                if value is not None:
                    update_doc[key] = value

            # Handle metadata merge.
            # NOTE(review): this is read-modify-write, so two concurrent updates
            # of the same ad can lose each other's metadata keys.
            if metadata is not None:
                existing_ad = await self.collection.find_one(query, {"metadata": 1})
                existing_metadata = (existing_ad or {}).get("metadata") or {}
                # Merge metadata (new values override old ones)
                update_doc["metadata"] = {**existing_metadata, **metadata}

            result = await self.collection.update_one(
                query,
                {"$set": update_doc}
            )
            return result.modified_count > 0
        except Exception as e:
            print(f"Failed to update ad creative: {e}")
            return False

    async def delete_ad_creative(self, ad_id: str, username: Optional[str] = None) -> bool:
        """
        Delete an ad creative by ID.

        If username is provided, only deletes if the ad belongs to that user.
        Returns True if a document was deleted.
        """
        if self.collection is None:
            return False

        try:
            query = {"_id": self._to_object_id(ad_id)}
            if username:
                query["username"] = username

            result = await self.collection.delete_one(query)
            return result.deleted_count > 0
        except Exception as e:
            print(f"Failed to delete ad creative: {e}")
            return False

    async def get_stats(self, username: Optional[str] = None) -> Dict[str, Any]:
        """
        Get statistics about stored ad creatives.

        If username is provided, only returns stats for that user's ads.
        Groups whose key is missing/empty are omitted from the breakdowns.
        """
        if self.collection is None:
            return {"connected": False}

        try:
            # Build base query
            base_query = {}
            if username:
                base_query["username"] = username

            total = await self.collection.count_documents(base_query)

            # Get counts by niche (length=None: don't silently truncate groups)
            pipeline_niche = [
                {"$match": base_query},
                {"$group": {"_id": "$niche", "count": {"$sum": 1}}}
            ]
            by_niche_cursor = self.collection.aggregate(pipeline_niche)
            by_niche_list = await by_niche_cursor.to_list(length=None)
            by_niche = {item["_id"]: item["count"] for item in by_niche_list if item["_id"]}

            # Get counts by generation method
            pipeline_method = [
                {"$match": base_query},
                {"$group": {"_id": "$generation_method", "count": {"$sum": 1}}}
            ]
            by_method_cursor = self.collection.aggregate(pipeline_method)
            by_method_list = await by_method_cursor.to_list(length=None)
            by_method = {item["_id"]: item["count"] for item in by_method_list if item["_id"]}

            return {
                "connected": True,
                "total_ads": total,
                "by_niche": by_niche,
                "by_method": by_method,
            }
        except Exception as e:
            print(f"Failed to get stats: {e}")
            return {"connected": True, "error": str(e)}

    # =============================================================================
    # USER MANAGEMENT METHODS
    # =============================================================================

    async def create_user(self, username: str, hashed_password: str) -> Optional[str]:
        """
        Create a new user in the database.

        Returns the new user's ID, or None if the user already exists, the
        database is not connected, or the insert fails.
        """
        if self.db is None:
            print("Database not connected - cannot create user")
            return None

        try:
            users_collection = self.db["users"]

            # Check if user already exists.
            # NOTE(review): check-then-insert is not atomic; without a unique
            # index on users.username, concurrent calls could create duplicates.
            existing = await users_collection.find_one({"username": username})
            if existing:
                print(f"User '{username}' already exists")
                return None

            # Create user document
            user_doc = {
                "username": username,
                "hashed_password": hashed_password,
                "created_at": datetime.now(timezone.utc),
            }

            result = await users_collection.insert_one(user_doc)
            print(f"✓ User '{username}' created successfully")
            return str(result.inserted_id)
        except Exception as e:
            print(f"Failed to create user: {e}")
            return None

    async def get_user(self, username: str) -> Optional[Dict[str, Any]]:
        """Get a user by username (the returned document includes hashed_password)."""
        if self.db is None:
            return None

        try:
            users_collection = self.db["users"]
            user = await users_collection.find_one({"username": username})
            if user:
                return self._serialize_document(user)
            return None
        except Exception as e:
            print(f"Failed to get user: {e}")
            return None

    async def list_users(self) -> List[Dict[str, Any]]:
        """List users, newest first, without passwords (at most 1000)."""
        if self.db is None:
            return []

        try:
            users_collection = self.db["users"]
            cursor = users_collection.find({}, {"hashed_password": 0}).sort("created_at", -1)
            users = await cursor.to_list(length=1000)

            # Convert documents
            results = [self._serialize_document(user) for user in users]
            return results
        except Exception as e:
            print(f"Failed to list users: {e}")
            return []

    async def delete_user(self, username: str) -> bool:
        """Delete a user by username. Returns True if a user was deleted."""
        if self.db is None:
            return False

        try:
            users_collection = self.db["users"]
            result = await users_collection.delete_one({"username": username})
            return result.deleted_count > 0
        except Exception as e:
            print(f"Failed to delete user: {e}")
            return False


# Global database service instance
db_service = DatabaseService()