|
|
""" |
|
|
Database Service for PsyAdGenesis |
|
|
Handles MongoDB connection and CRUD operations |
|
|
""" |
|
|
|
|
|
import json
import os
import sys
from datetime import datetime, timezone
from typing import Optional, Dict, Any, List

from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorClient

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from config import settings
|
|
|
|
|
|
|
|
class DatabaseService:
    """Async MongoDB database service for storing ad creatives and users.

    The handles ``client``, ``db`` and ``collection`` start as ``None`` and are
    only populated by :meth:`connect`. While they are ``None``, every public
    method takes a "not connected" path and returns ``None`` / ``False`` /
    empty results instead of raising. Database errors during an operation are
    printed and likewise turned into those default results.
    """

    def __init__(self):
        self.client: Optional[AsyncIOMotorClient] = None
        self.db = None
        self.collection = None
        self.mongodb_url = settings.mongodb_url
        self.db_name = settings.mongodb_db_name

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current time as a timezone-aware UTC datetime.

        Replaces the deprecated ``datetime.utcnow()``. MongoDB stores datetimes
        as UTC, so the persisted value is the same as before.
        """
        return datetime.now(timezone.utc)

    @staticmethod
    def _datetime_to_iso_utc(dt: datetime) -> str:
        """Convert a datetime to an ISO-8601 string that always ends in ``Z``.

        Naive datetimes are treated as already being UTC. The Motor client is
        created without ``tz_aware=True``, so values read back from MongoDB
        are naive UTC. Aware datetimes are converted to UTC first, so the
        result never carries a ``+HH:MM`` / ``-HH:MM`` offset.
        """
        if dt.utcoffset() is not None:
            dt = dt.astimezone(timezone.utc).replace(tzinfo=None)
        return dt.isoformat() + "Z"

    @staticmethod
    def _to_object_id(ad_id: Any) -> Any:
        """Return ``ObjectId(ad_id)`` when *ad_id* is a valid ObjectId, else *ad_id* unchanged.

        The raw value is kept as a fallback so documents whose ``_id`` is not
        an ObjectId can still be addressed.
        """
        return ObjectId(ad_id) if ObjectId.is_valid(ad_id) else ad_id

    def _close_client(self) -> None:
        """Close the Motor client (if any) and clear every handle derived from it."""
        if self.client is not None:
            self.client.close()
        self.client = None
        self.db = None
        self.collection = None

    def _serialize_document(self, doc: Dict[str, Any]) -> Dict[str, Any]:
        """Prepare a MongoDB document for a JSON response (mutates *doc* in place).

        - Moves ``_id`` into a string ``id`` field.
        - Converts ``created_at`` / ``updated_at`` datetimes to ISO strings with
          a ``Z`` (UTC) suffix.

        Empty or ``None`` documents are returned unchanged.
        """
        if not doc:
            return doc

        if "_id" in doc:
            doc["id"] = str(doc.pop("_id"))

        for field in ("created_at", "updated_at"):
            value = doc.get(field)
            if isinstance(value, datetime):
                doc[field] = self._datetime_to_iso_utc(value)

        return doc

    async def _ensure_indexes(self) -> None:
        """Create the indexes used by the ad-creative queries.

        ``create_index`` is a no-op when an identical index already exists.
        """
        await self.collection.create_index("niche")
        await self.collection.create_index("created_at")
        await self.collection.create_index([("niche", 1), ("created_at", -1)])
        await self.collection.create_index("generation_method")
        await self.collection.create_index("username")
        await self.collection.create_index([("username", 1), ("created_at", -1)])

    async def _count_by(self, field: str, base_query: Dict[str, Any]) -> Dict[Any, int]:
        """Return ``{value: count}`` of *field* over documents matching *base_query*.

        Groups whose value is falsy (missing field, ``None``, empty string) are
        skipped. All groups are returned; there is no cap on their number.
        """
        pipeline = [
            {"$match": base_query},
            {"$group": {"_id": f"${field}", "count": {"$sum": 1}}},
        ]
        groups = await self.collection.aggregate(pipeline).to_list(length=None)
        return {group["_id"]: group["count"] for group in groups if group["_id"]}

    # ------------------------------------------------------------------
    # Connection lifecycle
    # ------------------------------------------------------------------

    async def connect(self):
        """Create the connection to MongoDB and ensure the ad-creative indexes.

        Returns:
            True on success.

            False in three cases:

            - ``MONGODB_URL`` is not configured.
            - The server could not be reached. The client is closed and all
              handles are reset to ``None``.
            - Index creation failed. The handles stay usable in this case,
              as before.
        """
        if not self.mongodb_url:
            print("Warning: MONGODB_URL not configured. Database features disabled.")
            return False

        # Calling connect() again must not leak the previously opened client.
        self._close_client()

        try:
            self.client = AsyncIOMotorClient(self.mongodb_url)
            # Verify the server is reachable before exposing any handles.
            await self.client.admin.command('ping')
            self.db = self.client[self.db_name]
            self.collection = self.db["ad_creatives"]
        except Exception as e:
            print(f"✗ Database connection failed: {e}")
            # Don't leave an unusable client (and its connection pool) open.
            self._close_client()
            return False

        try:
            await self._ensure_indexes()
        except Exception as e:
            print(f"✗ Database connection failed: {e}")
            return False

        print(f"✓ Connected to MongoDB database: {self.db_name}")
        return True

    async def disconnect(self):
        """Close the database connection and clear all handles.

        Clearing ``db`` / ``collection`` makes later calls take the
        "not connected" path instead of operating on a closed client.
        """
        if self.client is not None:
            self._close_client()
            print("Database connection closed")

    # ------------------------------------------------------------------
    # Ad creatives
    # ------------------------------------------------------------------

    async def save_ad_creative(
        self,
        niche: str,
        title: str,
        headline: str,
        primary_text: str,
        description: str,
        body_story: str,
        cta: str,
        psychological_angle: str,
        why_it_works: str,
        username: str,
        image_url: Optional[str] = None,
        r2_url: Optional[str] = None,
        image_filename: Optional[str] = None,
        image_model: Optional[str] = None,
        image_seed: Optional[int] = None,
        image_prompt: Optional[str] = None,
        angle_key: Optional[str] = None,
        angle_name: Optional[str] = None,
        angle_trigger: Optional[str] = None,
        angle_category: Optional[str] = None,
        concept_key: Optional[str] = None,
        concept_name: Optional[str] = None,
        concept_structure: Optional[str] = None,
        concept_visual: Optional[str] = None,
        concept_category: Optional[str] = None,
        generation_method: str = "standard",
        metadata: Optional[Dict[str, Any]] = None,
    ) -> Optional[str]:
        """Save an ad creative to the ``ad_creatives`` collection.

        Fields whose value is ``None`` are omitted from the stored document.
        ``created_at`` is set to the current UTC time.

        Returns:
            The string ID of the inserted document. Returns None if the
            database is not connected or the insert failed.
        """
        if self.collection is None:
            print("Database not connected - skipping save")
            return None

        try:
            doc = {
                "niche": niche,
                "title": title,
                "headline": headline,
                "primary_text": primary_text,
                "description": description,
                "body_story": body_story,
                "cta": cta,
                "psychological_angle": psychological_angle,
                "why_it_works": why_it_works,
                "username": username,
                "image_url": image_url,
                "r2_url": r2_url,
                "image_filename": image_filename,
                "image_model": image_model,
                "image_seed": image_seed,
                "image_prompt": image_prompt,
                "angle_key": angle_key,
                "angle_name": angle_name,
                "angle_trigger": angle_trigger,
                "angle_category": angle_category,
                "concept_key": concept_key,
                "concept_name": concept_name,
                "concept_structure": concept_structure,
                "concept_visual": concept_visual,
                "concept_category": concept_category,
                "generation_method": generation_method,
                "metadata": metadata,
                "created_at": self._utcnow(),
            }
            # Drop unset optional fields instead of storing explicit nulls.
            doc = {k: v for k, v in doc.items() if v is not None}

            result = await self.collection.insert_one(doc)
            return str(result.inserted_id)
        except Exception as e:
            print(f"Failed to save ad creative: {e}")
            return None

    async def get_ad_creative(self, ad_id: str, username: Optional[str] = None) -> Optional[Dict[str, Any]]:
        """Get a single ad creative by ID.

        If *username* is truthy, the ad is only returned when it belongs to
        that user. Returns the serialized document, or None if it was not
        found, the database is not connected, or the query failed.
        """
        if self.collection is None:
            return None

        try:
            query: Dict[str, Any] = {"_id": self._to_object_id(ad_id)}
            if username:
                query["username"] = username

            doc = await self.collection.find_one(query)
            return self._serialize_document(doc) if doc else None
        except Exception as e:
            print(f"Failed to get ad creative: {e}")
            return None

    async def list_ad_creatives(
        self,
        username: str,
        niche: Optional[str] = None,
        generation_method: Optional[str] = None,
        limit: int = 50,
        offset: int = 0,
    ) -> tuple[List[Dict[str, Any]], int]:
        """List one user's ad creatives, newest first, with optional filters.

        Returns:
            ``(results, total_count)``. *results* is one page of serialized
            documents (``offset`` / ``limit``). *total_count* counts every
            document matching the filters. Returns ``([], 0)`` when not
            connected or on error.
        """
        if self.collection is None:
            return [], 0

        try:
            query: Dict[str, Any] = {"username": username}
            if niche:
                query["niche"] = niche
            if generation_method:
                query["generation_method"] = generation_method

            total_count = await self.collection.count_documents(query)

            cursor = self.collection.find(query).sort("created_at", -1).skip(offset).limit(limit)
            docs = await cursor.to_list(length=limit)

            return [self._serialize_document(doc) for doc in docs], total_count
        except Exception as e:
            print(f"Failed to list ad creatives: {e}")
            return [], 0

    async def update_ad_creative(
        self,
        ad_id: str,
        username: Optional[str] = None,
        image_url: Optional[str] = None,
        image_filename: Optional[str] = None,
        image_model: Optional[str] = None,
        image_prompt: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs
    ) -> bool:
        """Update an ad creative by ID.

        If *username* is truthy, the ad is only updated when it belongs to
        that user. ``updated_at`` is always set to the current UTC time.

        Args:
            ad_id: ID of the ad to update.
            username: Optional username used to verify ownership.
            image_url: New image URL.
            image_filename: New image filename.
            image_model: New image model.
            image_prompt: New image prompt.
            metadata: Metadata dict. It is shallow-merged over the existing
                metadata, with new keys winning.
            **kwargs: Additional fields to ``$set``. Entries whose value is
                None are ignored.

        Returns:
            True if a document was modified. False otherwise, including when
            the database is not connected or the update failed.
        """
        if self.collection is None:
            return False

        try:
            query: Dict[str, Any] = {"_id": self._to_object_id(ad_id)}
            if username:
                query["username"] = username

            update_doc: Dict[str, Any] = {"updated_at": self._utcnow()}

            optional_fields = {
                "image_url": image_url,
                "image_filename": image_filename,
                "image_model": image_model,
                "image_prompt": image_prompt,
                **kwargs,
            }
            update_doc.update({k: v for k, v in optional_fields.items() if v is not None})

            if metadata is not None:
                # NOTE: read-then-write merge. Concurrent metadata updates to
                # the same ad can overwrite each other's keys.
                existing_ad = await self.collection.find_one(query, {"metadata": 1})
                if existing_ad:
                    existing_metadata = existing_ad.get("metadata") or {}
                    update_doc["metadata"] = {**existing_metadata, **metadata}
                else:
                    update_doc["metadata"] = metadata

            result = await self.collection.update_one(query, {"$set": update_doc})
            return result.modified_count > 0
        except Exception as e:
            print(f"Failed to update ad creative: {e}")
            return False

    async def delete_ad_creative(self, ad_id: str, username: Optional[str] = None) -> bool:
        """Delete an ad creative by ID.

        If *username* is truthy, the ad is only deleted when it belongs to
        that user. Returns True if a document was deleted.
        """
        if self.collection is None:
            return False

        try:
            query: Dict[str, Any] = {"_id": self._to_object_id(ad_id)}
            if username:
                query["username"] = username

            result = await self.collection.delete_one(query)
            return result.deleted_count > 0
        except Exception as e:
            print(f"Failed to delete ad creative: {e}")
            return False

    async def get_stats(self, username: Optional[str] = None) -> Dict[str, Any]:
        """Get statistics about stored ad creatives.

        If *username* is truthy, only that user's ads are counted.

        Returns:
            ``{"connected": True, "total_ads", "by_niche", "by_method"}`` on
            success. ``{"connected": False}`` when not connected.
            ``{"connected": True, "error": ...}`` when a query fails.
        """
        if self.collection is None:
            return {"connected": False}

        try:
            base_query: Dict[str, Any] = {"username": username} if username else {}

            total = await self.collection.count_documents(base_query)
            by_niche = await self._count_by("niche", base_query)
            by_method = await self._count_by("generation_method", base_query)

            return {
                "connected": True,
                "total_ads": total,
                "by_niche": by_niche,
                "by_method": by_method,
            }
        except Exception as e:
            print(f"Failed to get stats: {e}")
            return {"connected": True, "error": str(e)}

    # ------------------------------------------------------------------
    # Users
    # ------------------------------------------------------------------

    async def create_user(self, username: str, hashed_password: str) -> Optional[str]:
        """Create a new user in the ``users`` collection.

        Returns the new user's ID as a string. Returns None if the database is
        not connected, the username is already taken, or the insert failed.

        NOTE: the existence check and the insert are separate operations, and
        connect() creates no unique index on ``users.username``. Concurrent
        calls could therefore create duplicate usernames.
        """
        if self.db is None:
            print("Database not connected - cannot create user")
            return None

        try:
            users_collection = self.db["users"]

            if await users_collection.find_one({"username": username}):
                print(f"User '{username}' already exists")
                return None

            user_doc = {
                "username": username,
                "hashed_password": hashed_password,
                "created_at": self._utcnow(),
            }

            result = await users_collection.insert_one(user_doc)
            print(f"✓ User '{username}' created successfully")
            return str(result.inserted_id)
        except Exception as e:
            print(f"Failed to create user: {e}")
            return None

    async def get_user(self, username: str) -> Optional[Dict[str, Any]]:
        """Get a user by username.

        Returns the serialized user document, including ``hashed_password``,
        or None if not found, not connected, or on error.
        """
        if self.db is None:
            return None

        try:
            user = await self.db["users"].find_one({"username": username})
            return self._serialize_document(user) if user else None
        except Exception as e:
            print(f"Failed to get user: {e}")
            return None

    async def list_users(self) -> List[Dict[str, Any]]:
        """List up to 1000 users, newest first, without ``hashed_password``."""
        if self.db is None:
            return []

        try:
            cursor = self.db["users"].find({}, {"hashed_password": 0}).sort("created_at", -1)
            users = await cursor.to_list(length=1000)
            return [self._serialize_document(user) for user in users]
        except Exception as e:
            print(f"Failed to list users: {e}")
            return []

    async def delete_user(self, username: str) -> bool:
        """Delete a user by username. Returns True if a user was deleted."""
        if self.db is None:
            return False

        try:
            result = await self.db["users"].delete_one({"username": username})
            return result.deleted_count > 0
        except Exception as e:
            print(f"Failed to delete user: {e}")
            return False
|
|
|
|
|
|
|
|
|
|
|
# Module-level instance shared by importers of this module. Its client/db/collection
# handles stay None (so every method takes its "not connected" path) until
# `await db_service.connect()` succeeds.
db_service: DatabaseService = DatabaseService()
|
|
|