"""Temporal workflow definitions for the AI Media OS."""

import asyncio
import functools
import os
import re
import urllib.parse
from dataclasses import asdict, dataclass, replace
from datetime import datetime, timedelta
from typing import Any, Dict, Optional

import cloudinary
import cloudinary.uploader
import httpx
from loguru import logger
from openai import OpenAI
from sqlalchemy import select
from temporalio import activity, workflow
from temporalio.common import RetryPolicy
from temporalio.exceptions import ApplicationError

from src.config import get_settings
from src.models.database import ApiUsage, Post as PostModel, Trend
from src.utils.database import AsyncSessionLocal


def _strip_markdown(text: str) -> str:
    """Remove markdown so captions look natural on Instagram.

    Bold/italic/inline-code/link markup is unwrapped, heading markers and list
    bullets at the start of a line are dropped, and runs of three or more
    newlines collapse to a single blank line. Hashtags (``#tag``) and
    underscores inside words (e.g. ``@my_handle``) are preserved.
    """
    text = re.sub(r"\*\*(.+?)\*\*", r"\1", text)
    text = re.sub(r"\*(.+?)\*", r"\1", text)
    # Underscore emphasis only when the underscores are not inside a word, so
    # handles/identifiers like "@my_cool_handle" are left intact.
    text = re.sub(r"(?<!\w)__(.+?)__(?!\w)", r"\1", text)
    text = re.sub(r"(?<!\w)_(.+?)_(?!\w)", r"\1", text)
    # Only line-leading "#"s followed by whitespace are markdown headings;
    # "#tag" hashtags are meaningful on Instagram and must survive.
    text = re.sub(r"^[ \t]*#{1,6}[ \t]+", "", text, flags=re.MULTILINE)
    text = re.sub(r"`(.+?)`", r"\1", text)
    text = re.sub(r"\[(.+?)\]\(.+?\)", r"\1", text)
    text = re.sub(r"^\s*[-*]\s+", "", text, flags=re.MULTILINE)
    text = re.sub(r"\n{3,}", "\n\n", text)
    return text.strip()


@dataclass
class TrendData:
    """Trend input data."""

    trend_id: int
    topic: str
    source: str
    score: float
    raw_data: Optional[Dict[str, Any]] = None


@dataclass
class PostData:
    """Generated post data."""

    post_id: int  # 0 until the draft is saved by save_post_draft
    content: str
    image_url: Optional[str] = None
    platform: str = "instagram"
    # Image prompt suggested by the content agent, if it returned one.
    image_prompt: Optional[str] = None


@dataclass
class PublishResult:
    """Result of publishing."""

    success: bool
    platform_post_id: Optional[str] = None
    error: Optional[str] = None


@dataclass
class GeneratePostInput:
    """Input for generate_post_content."""

    trend: TrendData
    agent_version_id: int


@dataclass
class GenerateImagesInput:
    """Input for generate_images."""

    topic: str
    count: int = 1  # NOTE(review): not read by generate_images, which always makes one image
    image_prompt: Optional[str] = None  # falls back to a topic-based prompt when empty


@dataclass
class ModerateContentInput:
    """Input for moderate_content."""

    post: PostData
    images: Dict[str, str]


@dataclass
class StoreMediaInput:
    """Input for store_media_to_cdn."""

    images: Dict[str, str]
    post_id: int


@dataclass
class SavePostDraftInput:
    """Input for save_post_draft."""

    post: PostData
    trend_id: int
    agent_version_id: int
    images: Dict[str, str]


@dataclass
class PublishToPlatformInput:
    """Input for publish_to_platform."""

    post: PostData
    platform: str = "instagram"


@dataclass
class RecordMetricsInput:
    """Input for record_execution_metrics."""

    post_id: int
    status: str
    cost_usd: float
    token_count: int
    platform_post_id: Optional[str] = None  # NOTE(review): accepted but not persisted


# Activity definitions


@activity.defn
async def fetch_trend_details(trend_id: int) -> TrendData:
    """Load a trend row from the database and return it as ``TrendData``.

    Raises:
        ApplicationError: Non-retryable, when no trend with ``trend_id`` exists.
            Retrying cannot make a missing row appear, and the workflow schedules
            this activity without its own retry policy.
    """
    logger.info(f"Fetching trend details for trend_id={trend_id}")

    async with AsyncSessionLocal() as session:
        result = await session.execute(select(Trend).where(Trend.id == trend_id))
        trend = result.scalar_one_or_none()

    if not trend:
        raise ApplicationError(
            f"Trend with ID {trend_id} not found",
            type="TrendNotFound",
            non_retryable=True,
        )

    return TrendData(
        trend_id=trend.id,
        topic=trend.topic,
        source=trend.source,
        score=trend.score or 0.0,
        raw_data=trend.raw_data,
    )


@activity.defn
async def generate_post_content(input: GeneratePostInput) -> PostData:
    """Generate caption text for a trend with the configured LLM agent.

    Args:
        input: The trend plus the agent version to run.

    Returns:
        An unsaved ``PostData`` (``post_id=0``) with the generated caption and,
        when the agent supplied one, an ``image_prompt``.

    Raises:
        Exception: If the agent reports failure or returns empty content. Both
            cases raise instead of substituting placeholder text, so the
            workflow's retry policy can retry and nothing bogus gets published.
    """
    from src.agents.post_generator import generate_post_with_agent

    logger.info(f"Generating content for trend: {input.trend.topic}")

    # asdict yields the same keys the agent expects:
    # trend_id, topic, source, score, raw_data.
    result = await generate_post_with_agent(asdict(input.trend), input.agent_version_id)

    if not result.get("success"):
        raise Exception(f"Agent failed to generate content: {result.get('error')}")

    content = result.get("post_content")
    if not content:
        raise Exception("Agent reported success but returned no post_content")

    return PostData(
        post_id=0,  # Placeholder until saved in DB
        content=content,
        platform="instagram",
        # NOTE(review): assumes the agent result may carry an "image_prompt" key,
        # as the workflow's image step expects; confirm against generate_post_with_agent.
        image_prompt=result.get("image_prompt") or None,
    )


@activity.defn
async def generate_images(input: GenerateImagesInput) -> Dict[str, str]:
    """Generate a single image for the post and return ``{"image_1": url}``.

    If the configured API base points at OpenRouter, a Pollinations.AI URL is
    built and fetched once to confirm it renders. Otherwise DALL-E 3 is called
    through the OpenAI client.

    Raises:
        Exception: Any generation/HTTP failure is logged and re-raised so the
            workflow's retry policy applies.
    """
    logger.info(f"Generating image for topic: {input.topic}")
    settings = get_settings()

    # The API base may be unset; treat that as "not OpenRouter".
    api_base = settings.openai_api_base or ""
    is_openrouter = "openrouter.ai" in api_base

    # Use the LLM-written image prompt if available, otherwise fall back to topic.
    base_prompt = input.image_prompt or (
        f"A photorealistic scene representing: {input.topic}. "
        f"Cinematic lighting, high quality photography."
    )
    prompt = f"{base_prompt} No text, no words, no letters, no typography, no watermarks, no captions."
    negative = "text, words, letters, typography, watermark, caption, title, logo, grid, mesh, overlay, graphic design"

    try:
        if is_openrouter:
            logger.info("Using Pollinations.AI (free, no key required) for image generation...")
            # safe="" so a "/" in the prompt is escaped instead of splitting the URL path.
            encoded_prompt = urllib.parse.quote(prompt, safe="")
            encoded_negative = urllib.parse.quote(negative, safe="")
            image_url = (
                f"https://image.pollinations.ai/prompt/{encoded_prompt}"
                f"?width=1024&height=1024&nologo=true&model=flux&negative={encoded_negative}"
            )

            # Verify the URL is reachable (this also triggers the render).
            async with httpx.AsyncClient() as http:
                response = await http.get(image_url, timeout=120, follow_redirects=True)
            if response.status_code != 200:
                logger.error(f"Pollinations.AI Error ({response.status_code})")
                response.raise_for_status()
            logger.info(f"Pollinations.AI image URL ready: {image_url}")
        else:
            logger.info("Using standard OpenAI DALL-E 3...")
            client = OpenAI(api_key=settings.openai_api_key, base_url=api_base or None)
            # The OpenAI client is synchronous; run it in the default executor so
            # the worker's event loop is not blocked for the whole generation.
            loop = asyncio.get_running_loop()
            response = await loop.run_in_executor(
                None,
                functools.partial(
                    client.images.generate,
                    model="dall-e-3",
                    prompt=prompt,
                    size="1024x1024",
                    quality="standard",
                    n=1,
                ),
            )
            image_url = response.data[0].url

        if not image_url:
            raise Exception("Failed to extract image URL from AI response")
    except Exception as e:
        logger.error(f"IMAGE GENERATION CRITICAL FAILURE: {str(e)}")
        raise

    logger.info(f"Successfully generated AI image: {image_url}")
    return {"image_1": image_url}


@activity.defn
async def moderate_content(input: ModerateContentInput) -> Dict[str, Any]:
    """Check the post text for moderation issues.

    Uses the OpenAI Moderation API when ``OPENAI_API_KEY`` is set to a
    non-placeholder value. If the key is missing or the call fails, a keyword
    check is used instead. Only ``input.post.content`` is inspected; the
    images are not checked here.

    Returns:
        A dict with ``approved`` (bool), ``nsfw_scores`` and ``moderation_notes``.
    """
    logger.info(f"Running moderation checks on post {input.post.post_id}")

    api_key = os.getenv("OPENAI_API_KEY")
    if api_key and not api_key.startswith("sk-xxx"):
        try:
            async with httpx.AsyncClient() as client:
                response = await client.post(
                    "https://api.openai.com/v1/moderations",
                    headers={"Authorization": f"Bearer {api_key}"},
                    json={"input": input.post.content},
                )
                response.raise_for_status()
                result = response.json()
                flagged = result["results"][0]["flagged"]
                return {
                    "approved": not flagged,
                    "nsfw_scores": result["results"][0]["categories"],
                    "moderation_notes": "OpenAI automated check",
                }
        except Exception as e:
            # Best-effort: fall through to the keyword check below.
            logger.error(f"Moderation API failed: {e}")

    # Simple keyword check as fallback. A keyword must start a word, so
    # "hateful" is caught but "whatever" (which contains "hate") is not.
    blocked_words = ["nsfw", "violence", "hate"]
    content_lower = input.post.content.lower()
    for word in blocked_words:
        if re.search(rf"\b{re.escape(word)}", content_lower):
            return {
                "approved": False,
                "nsfw_scores": {},
                "moderation_notes": f"Flagged by keyword: {word}",
            }

    return {
        "approved": True,
        "nsfw_scores": {},
        "moderation_notes": "Simple keyword check (Fallback)",
    }


@activity.defn
async def store_media_to_cdn(input: StoreMediaInput) -> Dict[str, str]:
    """Upload each image URL to Cloudinary and return the CDN URLs.

    Returns a dict with the same keys as ``input.images``. A key keeps its
    original URL in three cases: the URL is a dummy ``example.com`` URL, the
    upload fails, or Cloudinary returns no ``secure_url``.
    """
    logger.info(f"Uploading media for post {input.post_id} to Cloudinary")
    settings = get_settings()

    cloudinary.config(
        cloud_name=settings.cloudinary_cloud_name,
        api_key=settings.cloudinary_api_key,
        api_secret=settings.cloudinary_api_secret,
        secure=True,
    )

    loop = asyncio.get_running_loop()
    cdn_urls: Dict[str, str] = {}
    for key, url in input.images.items():
        try:
            logger.info(f"Uploading {url} to Cloudinary...")

            if "example.com" in url:
                logger.warning(f"Skipping dummy URL upload: {url}")
                cdn_urls[key] = url
                continue

            # Cloudinary's SDK is blocking; run it off the event loop.
            response = await loop.run_in_executor(
                None,
                functools.partial(
                    cloudinary.uploader.upload,
                    url,
                    folder=f"ai_media_os/post_{input.post_id}",
                    resource_type="auto",
                ),
            )
            secure_url = response.get("secure_url")
            if secure_url:
                cdn_urls[key] = secure_url
                logger.info(f"Successfully uploaded to Cloudinary: {secure_url}")
            else:
                logger.warning(f"Cloudinary response had no secure_url; keeping {url}")
                cdn_urls[key] = url
        except Exception as e:
            logger.error(f"Failed to upload to Cloudinary: {e}")
            # Fallback to original URL
            cdn_urls[key] = url

    return cdn_urls


@activity.defn
async def save_post_draft(input: SavePostDraftInput) -> PostData:
    """Insert the post as a pending draft and return it with its new DB id.

    The stored content is passed through ``_strip_markdown``. The image URL is
    taken from ``input.images["image_1"]`` if present.
    """
    logger.info(f"Saving post draft for trend {input.trend_id}")

    async with AsyncSessionLocal() as session:
        new_post = PostModel(
            trend_id=input.trend_id,
            agent_version_id=input.agent_version_id,
            content=_strip_markdown(input.post.content),
            image_url=input.images.get("image_1"),
            platform=input.post.platform,
            status="draft",
            approval_status="pending",
        )
        session.add(new_post)
        await session.commit()
        await session.refresh(new_post)

        return PostData(
            post_id=new_post.id,
            content=new_post.content,
            image_url=new_post.image_url,
            platform=new_post.platform,
        )


@activity.defn
async def publish_to_platform(input: PublishToPlatformInput) -> PublishResult:
    """Publish the post via ``SocialMediaPublisher`` and wrap the outcome.

    A missing or falsy ``success`` in the publisher's result counts as failure.
    """
    logger.info(f"Publishing post {input.post.post_id} to {input.platform}")

    from src.services.social_media import SocialMediaPublisher

    publisher = SocialMediaPublisher()
    clean_text = _strip_markdown(input.post.content)
    result = await publisher.publish_post(
        platform=input.platform,
        text=clean_text,
        image_urls=[input.post.image_url] if input.post.image_url else [],
    )

    if result.get("success"):
        return PublishResult(
            success=True,
            platform_post_id=result.get("platform_post_id", "unknown"),
            error=None,
        )
    return PublishResult(
        success=False,
        platform_post_id=None,
        error=result.get("error", "Unknown publishing error"),
    )


@activity.defn
async def record_execution_metrics(input: RecordMetricsInput) -> None:
    """Add token/cost totals to the post and to today's aggregate ``openai`` usage row.

    The post's status is set to ``input.status``, and ``published_at`` is
    stamped when that status is ``"published"``. The daily ``ApiUsage`` row is
    keyed by midnight of the current ``datetime.utcnow()`` day and is created if
    missing. It is updated even when the post id is not found.
    """
    logger.info(
        f"Recording metrics for post {input.post_id}: "
        f"status={input.status}, cost=${input.cost_usd}, tokens={input.token_count}"
    )

    async with AsyncSessionLocal() as session:
        result = await session.execute(select(PostModel).where(PostModel.id == input.post_id))
        post = result.scalar_one_or_none()

        if post:
            post.token_count = (post.token_count or 0) + input.token_count
            post.generation_cost = (post.generation_cost or 0.0) + input.cost_usd
            post.status = input.status
            if input.status == "published":
                post.published_at = datetime.utcnow()

        # Track aggregate API usage per (naive UTC) day.
        period_start = datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0)
        usage_result = await session.execute(
            select(ApiUsage).where(ApiUsage.service == "openai", ApiUsage.period_start == period_start)
        )
        usage = usage_result.scalar_one_or_none()

        if not usage:
            usage = ApiUsage(
                service="openai",
                period_start=period_start,
                period_end=period_start.replace(hour=23, minute=59, second=59),
                request_count=1,
                token_count=input.token_count,
                cost_usd=input.cost_usd,
            )
            session.add(usage)
        else:
            usage.request_count += 1
            usage.token_count += input.token_count
            usage.cost_usd += input.cost_usd

        await session.commit()


# Workflow definition


@workflow.defn
class TrendToPostPublishWorkflow:
    """
    Main workflow: Trend → Content Generation → Moderation → Publishing.

    Flow:
    1. Fetch trend details
    2. Generate post content (LLM)
    3. Generate images
    4. Moderate content
    5. Save draft in DB
    6. Upload to CDN
    7. Wait for moderator approval (signal), unless auto-approve is enabled
    8. Publish to platform
    9. Record metrics
    """

    def __init__(self) -> None:
        # None until the approve_post signal arrives (or auto-approve sets True).
        self._is_approved: Optional[bool] = None

    @workflow.signal
    def approve_post(self, approved: bool) -> None:
        """Record the moderator's decision; unblocks the approval wait in ``run``."""
        self._is_approved = approved

    @workflow.run
    async def run(
        self,
        trend_id: int,
        agent_version_id: int,
        platform: str = "instagram",
    ) -> Dict[str, Any]:
        """Execute the workflow and return a status dict.

        Possible ``status`` values: ``success``, ``rejected``,
        ``rejected_by_moderator``, ``publish_failed`` and ``failed``. Any
        exception, including an approval timeout, produces ``failed``.
        """
        # workflow.logger is replay-aware: it does not re-emit logs during replay.
        workflow.logger.info(f"Starting workflow for trend_id={trend_id}")

        try:
            # Step 1: Fetch trend
            trend = await workflow.execute_activity(
                fetch_trend_details,
                trend_id,
                start_to_close_timeout=timedelta(seconds=30),
            )

            # Step 2: Generate content
            post = await workflow.execute_activity(
                generate_post_content,
                GeneratePostInput(trend=trend, agent_version_id=agent_version_id),
                start_to_close_timeout=timedelta(minutes=5),
                retry_policy=RetryPolicy(
                    maximum_attempts=5,
                    initial_interval=timedelta(seconds=2),
                    backoff_coefficient=2.0,
                ),
            )

            # Step 3: Generate images — use the LLM-crafted image_prompt if available
            images = await workflow.execute_activity(
                generate_images,
                GenerateImagesInput(
                    topic=trend.topic,
                    count=1,
                    image_prompt=post.image_prompt,
                ),
                start_to_close_timeout=timedelta(minutes=10),
                retry_policy=RetryPolicy(
                    maximum_attempts=2,
                    initial_interval=timedelta(seconds=5),
                ),
            )

            # Step 4: Moderation
            moderation = await workflow.execute_activity(
                moderate_content,
                ModerateContentInput(post=post, images=images),
                start_to_close_timeout=timedelta(seconds=60),
            )

            if not moderation.get("approved", False):
                workflow.logger.warning(f"Content moderation rejected for trend {trend_id}")
                return {
                    "status": "rejected",
                    "reason": "Moderation failed",
                    "post_id": None,
                }

            # Step 5: Save draft first (so we get a real post_id)
            post = await workflow.execute_activity(
                save_post_draft,
                SavePostDraftInput(
                    post=post,
                    trend_id=trend_id,
                    agent_version_id=agent_version_id,
                    images=images,
                ),
                start_to_close_timeout=timedelta(seconds=30),
            )

            # Step 6: Upload to CDN (now we have a real post_id)
            cdn_images = await workflow.execute_activity(
                store_media_to_cdn,
                StoreMediaInput(images=images, post_id=post.post_id),
                start_to_close_timeout=timedelta(seconds=60),
            )

            # Point the post at the CDN copy, keeping the draft URL as fallback.
            post = replace(post, image_url=cdn_images.get("image_1", post.image_url))

            # Step 7: Approval — auto or human-in-the-loop.
            # NOTE(review): reading settings inside workflow code is not replay-safe
            # if AUTO_APPROVE can change between a run and its replay; consider
            # passing it as a workflow argument instead.
            settings = get_settings()
            if settings.auto_approve:
                workflow.logger.info(f"AUTO_APPROVE enabled — skipping human review for post {post.post_id}")
                self._is_approved = True
            else:
                workflow.logger.info(f"Waiting for moderator approval of post {post.post_id}")
                try:
                    await workflow.wait_condition(
                        lambda: self._is_approved is not None,
                        timeout=timedelta(hours=24),
                    )
                except asyncio.TimeoutError:
                    # Previously this fell into the generic handler with an empty message.
                    workflow.logger.warning(f"Approval timed out for post {post.post_id}")
                    return {
                        "status": "failed",
                        "error": "Timed out waiting for moderator approval",
                        "post_id": post.post_id,
                    }

            if not self._is_approved:
                workflow.logger.info(f"Post {post.post_id} was rejected by moderator")
                return {
                    "status": "rejected_by_moderator",
                    "post_id": post.post_id,
                }

            # Step 8: Publish
            result = await workflow.execute_activity(
                publish_to_platform,
                PublishToPlatformInput(post=post, platform=platform),
                start_to_close_timeout=timedelta(minutes=2),
                retry_policy=RetryPolicy(
                    maximum_attempts=3,
                    initial_interval=timedelta(seconds=5),
                ),
            )

            if not result.success:
                workflow.logger.error(f"Publishing failed for post {post.post_id}: {result.error}")
                return {
                    "status": "publish_failed",
                    "post_id": post.post_id,
                    "error": result.error,
                }

            # Step 9: Record metrics
            # NOTE(review): cost/token figures are hard-coded placeholders, not measured values.
            await workflow.execute_activity(
                record_execution_metrics,
                RecordMetricsInput(
                    post_id=post.post_id,
                    status="published",
                    platform_post_id=result.platform_post_id,
                    cost_usd=0.50,
                    token_count=250,
                ),
                start_to_close_timeout=timedelta(seconds=30),
            )

            workflow.logger.info(f"Workflow completed successfully for post {post.post_id}")
            return {
                "status": "success",
                "post_id": post.post_id,
                "platform_post_id": result.platform_post_id,
            }

        except Exception as e:
            workflow.logger.error(f"Workflow error: {e}")
            return {
                "status": "failed",
                "error": str(e),
            }