sushilideaclan01's picture
removed the containers and made some optimizations
505ff55
"""
Database Service for PsyAdGenesis
Handles MongoDB connection and CRUD operations
"""
import json
import os
import sys
from datetime import datetime, timezone
from typing import Optional, Dict, Any, List

from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorClient

# Put this file's grandparent directory on sys.path so `config` can be imported.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config import settings  # noqa: E402
class DatabaseService:
    """Async MongoDB service for ad creatives and user accounts.

    Ad creatives are stored in the ``ad_creatives`` collection and users in the
    ``users`` collection of the database named by ``settings.mongodb_db_name``.
    Call :meth:`connect` before use. Public methods never raise: when the
    service is not connected they return an "empty" value (``None``, ``False``,
    ``[]`` ...), and when a driver call fails the error is printed and an empty
    value is returned.
    """

    def __init__(self):
        # Connection handles: populated by connect(), cleared by disconnect().
        self.client: Optional[AsyncIOMotorClient] = None
        self.db = None
        self.collection = None  # the "ad_creatives" collection
        self.mongodb_url = settings.mongodb_url
        self.db_name = settings.mongodb_db_name

    # =========================================================================
    # INTERNAL HELPERS
    # =========================================================================

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current time as a timezone-aware UTC datetime.

        Replaces ``datetime.utcnow()``, which is deprecated since Python 3.12.
        PyMongo stores aware datetimes as UTC, so persisted values are unchanged.
        """
        return datetime.now(timezone.utc)

    @staticmethod
    def _datetime_to_iso_utc(dt: datetime) -> str:
        """Format *dt* as an ISO-8601 string that always ends in ``'Z'`` (UTC).

        Naive datetimes are assumed to already be UTC. PyMongo returns naive
        UTC datetimes unless the client is created with ``tz_aware=True``.
        Aware datetimes are converted to UTC first, so ``05:30-05:00`` is
        rendered as ``10:30Z`` rather than keeping its original offset.
        """
        if dt.tzinfo is not None and dt.utcoffset() is not None:
            dt = dt.astimezone(timezone.utc).replace(tzinfo=None)
        return dt.isoformat() + "Z"

    def _serialize_document(self, doc: Dict[str, Any]) -> Dict[str, Any]:
        """Make a MongoDB document JSON-friendly, modifying it in place.

        - ``_id`` is replaced by a string ``id`` field.
        - ``created_at`` / ``updated_at`` datetimes become ISO strings ending in
          ``'Z'``. Non-datetime values in those fields are left untouched.

        Falsy input (``None`` or ``{}``) is returned unchanged.
        """
        if not doc:
            return doc
        if "_id" in doc:
            doc["id"] = str(doc.pop("_id"))
        for field in ("created_at", "updated_at"):
            value = doc.get(field)
            if isinstance(value, datetime):
                doc[field] = self._datetime_to_iso_utc(value)
        return doc

    @staticmethod
    def _build_id_query(ad_id: str, username: Optional[str] = None) -> Dict[str, Any]:
        """Build the filter that selects a single ad creative.

        *ad_id* is converted to an ``ObjectId`` when it is a valid one.
        Otherwise it is used verbatim, so documents with string ``_id`` values
        can still be addressed. When *username* is truthy, the filter also
        requires the ad to belong to that user.
        """
        object_id = ObjectId(ad_id) if ObjectId.is_valid(ad_id) else ad_id
        query: Dict[str, Any] = {"_id": object_id}
        if username:
            query["username"] = username
        return query

    async def _ensure_indexes(self) -> None:
        """Create the query indexes on ``ad_creatives`` (best-effort).

        The indexes only speed up queries. A failure (e.g. a conflicting
        existing index) is reported but does not disable the database.
        """
        try:
            await self.collection.create_index("niche")
            await self.collection.create_index("created_at")
            await self.collection.create_index([("niche", 1), ("created_at", -1)])
            await self.collection.create_index("generation_method")
            await self.collection.create_index("username")  # per-user lookups
            await self.collection.create_index([("username", 1), ("created_at", -1)])  # user + date
        except Exception as e:
            print(f"Warning: failed to create database indexes: {e}")

    async def _count_by_field(self, field: str, base_query: Dict[str, Any]) -> Dict[Any, int]:
        """Return ``{value: count}`` for *field* over ads matching *base_query*.

        Groups whose value is falsy (missing, ``None``, empty string) are
        dropped. All groups are returned, with no cap on how many values exist.
        """
        pipeline = [
            {"$match": base_query},
            {"$group": {"_id": f"${field}", "count": {"$sum": 1}}},
        ]
        groups = await self.collection.aggregate(pipeline).to_list(length=None)
        return {group["_id"]: group["count"] for group in groups if group["_id"]}

    # =========================================================================
    # CONNECTION MANAGEMENT
    # =========================================================================

    async def connect(self):
        """Connect to MongoDB and prepare the ``ad_creatives`` collection.

        Returns:
            True when the server answered ``ping`` and the collection handle is
            ready. False when ``MONGODB_URL`` is not configured or connecting
            failed. In that case the newly created client is closed and the
            service's existing state is left untouched.
        """
        if not self.mongodb_url:
            print("Warning: MONGODB_URL not configured. Database features disabled.")
            return False
        client = None
        try:
            client = AsyncIOMotorClient(self.mongodb_url)
            # Fail fast if the server is unreachable or credentials are wrong.
            await client.admin.command('ping')
            db = client[self.db_name]
            collection = db["ad_creatives"]
        except Exception as e:
            if client is not None:
                # Don't leak the client's connection pool on failure.
                client.close()
            print(f"✗ Database connection failed: {e}")
            return False
        self.client, self.db, self.collection = client, db, collection
        await self._ensure_indexes()
        print(f"✓ Connected to MongoDB database: {self.db_name}")
        return True

    async def disconnect(self):
        """Close the client (if any) and reset the connection state.

        Clearing ``db``/``collection`` makes later calls take the
        "not connected" path instead of talking to a closed client.
        """
        if self.client is not None:
            self.client.close()
            print("Database connection closed")
        self.client = None
        self.db = None
        self.collection = None

    # =========================================================================
    # AD CREATIVE METHODS
    # =========================================================================

    async def save_ad_creative(
        self,
        niche: str,
        title: str,
        headline: str,
        primary_text: str,
        description: str,
        body_story: str,
        cta: str,
        psychological_angle: str,
        why_it_works: str,
        username: str,  # Required: username of the user creating the ad
        image_url: Optional[str] = None,
        r2_url: Optional[str] = None,
        image_filename: Optional[str] = None,
        image_model: Optional[str] = None,
        image_seed: Optional[int] = None,
        image_prompt: Optional[str] = None,
        angle_key: Optional[str] = None,
        angle_name: Optional[str] = None,
        angle_trigger: Optional[str] = None,
        angle_category: Optional[str] = None,
        concept_key: Optional[str] = None,
        concept_name: Optional[str] = None,
        concept_structure: Optional[str] = None,
        concept_visual: Optional[str] = None,
        concept_category: Optional[str] = None,
        generation_method: str = "standard",
        metadata: Optional[Dict[str, Any]] = None,
    ) -> Optional[str]:
        """Insert an ad creative owned by *username*.

        Fields passed as ``None`` are omitted from the stored document.
        ``created_at`` is set to the current UTC time.

        Returns:
            The inserted document's id as a string, or None if not connected
            or the insert failed.
        """
        if self.collection is None:
            print("Database not connected - skipping save")
            return None
        try:
            doc = {
                "niche": niche,
                "title": title,
                "headline": headline,
                "primary_text": primary_text,
                "description": description,
                "body_story": body_story,
                "cta": cta,
                "psychological_angle": psychological_angle,
                "why_it_works": why_it_works,
                "username": username,  # stored so listings can filter by owner
                "image_url": image_url,
                "r2_url": r2_url,
                "image_filename": image_filename,
                "image_model": image_model,
                "image_seed": image_seed,
                "image_prompt": image_prompt,
                "angle_key": angle_key,
                "angle_name": angle_name,
                "angle_trigger": angle_trigger,
                "angle_category": angle_category,
                "concept_key": concept_key,
                "concept_name": concept_name,
                "concept_structure": concept_structure,
                "concept_visual": concept_visual,
                "concept_category": concept_category,
                "generation_method": generation_method,
                "metadata": metadata,
                "created_at": self._utcnow(),
            }
            # Drop None values to keep stored documents compact.
            doc = {k: v for k, v in doc.items() if v is not None}
            result = await self.collection.insert_one(doc)
            return str(result.inserted_id)
        except Exception as e:
            print(f"Failed to save ad creative: {e}")
            return None

    async def get_ad_creative(self, ad_id: str, username: Optional[str] = None) -> Optional[Dict[str, Any]]:
        """Fetch one ad creative by id, serialized for JSON.

        If *username* is given, the ad is returned only when it belongs to
        that user. Returns None when not found, not connected, or on error.
        """
        if self.collection is None:
            return None
        try:
            doc = await self.collection.find_one(self._build_id_query(ad_id, username))
            return self._serialize_document(doc) if doc else None
        except Exception as e:
            print(f"Failed to get ad creative: {e}")
            return None

    async def list_ad_creatives(
        self,
        username: str,  # Required: filter by username
        niche: Optional[str] = None,
        generation_method: Optional[str] = None,
        limit: int = 50,
        offset: int = 0,
    ) -> tuple[List[Dict[str, Any]], int]:
        """List *username*'s ad creatives, newest first, with optional filters.

        Returns:
            ``(page, total_count)``. ``page`` holds at most *limit* serialized
            ads after skipping *offset*. ``total_count`` counts every match
            before pagination. Returns ``([], 0)`` when not connected or on error.
        """
        if self.collection is None:
            return [], 0
        try:
            query: Dict[str, Any] = {"username": username}  # always scope to the owner
            if niche:
                query["niche"] = niche
            if generation_method:
                query["generation_method"] = generation_method
            total_count = await self.collection.count_documents(query)
            cursor = self.collection.find(query).sort("created_at", -1).skip(offset).limit(limit)
            docs = await cursor.to_list(length=limit)
            return [self._serialize_document(doc) for doc in docs], total_count
        except Exception as e:
            print(f"Failed to list ad creatives: {e}")
            return [], 0

    async def update_ad_creative(
        self,
        ad_id: str,
        username: Optional[str] = None,
        image_url: Optional[str] = None,
        image_filename: Optional[str] = None,
        image_model: Optional[str] = None,
        image_prompt: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs
    ) -> bool:
        """Update an ad creative by id.

        Args:
            ad_id: Id of the ad to update.
            username: If given, the update applies only to that user's ad.
            image_url / image_filename / image_model / image_prompt: New values.
                ``None`` leaves the stored value unchanged.
            metadata: Shallow-merged into the existing metadata; new keys win.
            **kwargs: Extra fields to ``$set``. ``None`` values are ignored.

        Returns:
            True if a document was modified. ``updated_at`` is refreshed on
            every call, so this is normally True whenever the filter matches.
            False when not connected, nothing matched, or on error.
        """
        if self.collection is None:
            return False
        try:
            query = self._build_id_query(ad_id, username)
            update_doc: Dict[str, Any] = {"updated_at": self._utcnow()}
            named_fields = {
                "image_url": image_url,
                "image_filename": image_filename,
                "image_model": image_model,
                "image_prompt": image_prompt,
            }
            # kwargs are applied after the named fields, so they may also
            # override "updated_at".
            for key, value in {**named_fields, **kwargs}.items():
                if value is not None:
                    update_doc[key] = value
            if metadata is not None:
                # NOTE: read-modify-write, not atomic. Concurrent metadata
                # updates to the same ad can overwrite each other's keys.
                existing_ad = await self.collection.find_one(query)
                existing_metadata = (existing_ad or {}).get("metadata") or {}
                update_doc["metadata"] = {**existing_metadata, **metadata}
            result = await self.collection.update_one(query, {"$set": update_doc})
            return result.modified_count > 0
        except Exception as e:
            print(f"Failed to update ad creative: {e}")
            return False

    async def delete_ad_creative(self, ad_id: str, username: Optional[str] = None) -> bool:
        """Delete an ad creative by id.

        If *username* is given, only that user's ad is deleted. Returns True
        if a document was deleted.
        """
        if self.collection is None:
            return False
        try:
            result = await self.collection.delete_one(self._build_id_query(ad_id, username))
            return result.deleted_count > 0
        except Exception as e:
            print(f"Failed to delete ad creative: {e}")
            return False

    async def get_stats(self, username: Optional[str] = None) -> Dict[str, Any]:
        """Summarize stored ad creatives, optionally only *username*'s.

        Returns:
            ``{"connected": False}`` when not connected. Otherwise a dict with
            ``total_ads`` plus ``by_niche`` / ``by_method`` count maps. On error
            it returns ``{"connected": True, "error": ...}``.
        """
        if self.collection is None:
            return {"connected": False}
        try:
            base_query: Dict[str, Any] = {}
            if username:
                base_query["username"] = username
            total = await self.collection.count_documents(base_query)
            by_niche = await self._count_by_field("niche", base_query)
            by_method = await self._count_by_field("generation_method", base_query)
            return {
                "connected": True,
                "total_ads": total,
                "by_niche": by_niche,
                "by_method": by_method,
            }
        except Exception as e:
            print(f"Failed to get stats: {e}")
            return {"connected": True, "error": str(e)}

    # =========================================================================
    # USER MANAGEMENT METHODS
    # =========================================================================

    async def create_user(self, username: str, hashed_password: str) -> Optional[str]:
        """Create a user. Returns the new id, or None if the user exists or on error.

        NOTE(review): the existence check and the insert are separate calls,
        and this class creates no unique index on ``users.username``.
        Concurrent calls could therefore create duplicates.
        """
        if self.db is None:
            print("Database not connected - cannot create user")
            return None
        try:
            users_collection = self.db["users"]
            existing = await users_collection.find_one({"username": username})
            if existing:
                print(f"User '{username}' already exists")
                return None
            user_doc = {
                "username": username,
                "hashed_password": hashed_password,
                "created_at": self._utcnow(),
            }
            result = await users_collection.insert_one(user_doc)
            print(f"✓ User '{username}' created successfully")
            return str(result.inserted_id)
        except Exception as e:
            print(f"Failed to create user: {e}")
            return None

    async def get_user(self, username: str) -> Optional[Dict[str, Any]]:
        """Fetch a user by username, including ``hashed_password`` (no projection)."""
        if self.db is None:
            return None
        try:
            user = await self.db["users"].find_one({"username": username})
            return self._serialize_document(user) if user else None
        except Exception as e:
            print(f"Failed to get user: {e}")
            return None

    async def list_users(self) -> List[Dict[str, Any]]:
        """List up to 1000 users, newest first, without ``hashed_password``."""
        if self.db is None:
            return []
        try:
            cursor = self.db["users"].find({}, {"hashed_password": 0}).sort("created_at", -1)
            users = await cursor.to_list(length=1000)
            return [self._serialize_document(user) for user in users]
        except Exception as e:
            print(f"Failed to list users: {e}")
            return []

    async def delete_user(self, username: str) -> bool:
        """Delete a user by username. Returns True if a user was deleted.

        The user's ad creatives are not touched.
        """
        if self.db is None:
            return False
        try:
            result = await self.db["users"].delete_one({"username": username})
            return result.deleted_count > 0
        except Exception as e:
            print(f"Failed to delete user: {e}")
            return False
# Module-wide DatabaseService shared by importers. Constructing it only reads
# settings; nothing is connected until `await db_service.connect()` runs.
db_service: DatabaseService = DatabaseService()