Spaces:
Running
Running
| """Temporal workflow definitions for the AI Media OS.""" | |
| import re | |
| from datetime import timedelta | |
| from temporalio import workflow, activity | |
| from temporalio.common import RetryPolicy | |
| from dataclasses import dataclass | |
| from typing import Optional, Dict, Any | |
| from loguru import logger | |
| from sqlalchemy import select | |
| from src.utils.database import AsyncSessionLocal | |
| from src.models.database import Trend, Post as PostModel, ApiUsage | |
| from datetime import datetime | |
| import httpx | |
| import os | |
| import cloudinary | |
| import cloudinary.uploader | |
| from src.config import get_settings | |
| from openai import OpenAI | |
| def _strip_markdown(text: str) -> str: | |
| """Remove markdown so captions look natural on Instagram.""" | |
| text = re.sub(r'\*\*(.+?)\*\*', r'\1', text) | |
| text = re.sub(r'\*(.+?)\*', r'\1', text) | |
| text = re.sub(r'__(.+?)__', r'\1', text) | |
| text = re.sub(r'_(.+?)_', r'\1', text) | |
| text = re.sub(r'#+\s*', '', text) | |
| text = re.sub(r'`(.+?)`', r'\1', text) | |
| text = re.sub(r'\[(.+?)\]\(.+?\)', r'\1', text) | |
| text = re.sub(r'^\s*[-*]\s+', '', text, flags=re.MULTILINE) | |
| text = re.sub(r'\n{3,}', '\n\n', text) | |
| return text.strip() | |
@dataclass
class TrendData:
    """Trend input data loaded by ``fetch_trend_details``.

    Must be a dataclass: it is built with keyword arguments and passed
    across Temporal activity/workflow boundaries.
    """

    trend_id: int
    topic: str
    source: str
    score: float
    # Raw payload stored alongside the trend; handed to the post generator as-is.
    raw_data: Optional[Dict[str, Any]] = None
@dataclass
class PostData:
    """Generated post data passed between activities.

    ``post_id`` is ``0`` until the draft has been saved to the database.
    """

    post_id: int
    content: str
    image_url: Optional[str] = None
    platform: str = "instagram"
@dataclass
class PublishResult:
    """Outcome of ``publish_to_platform``.

    On success ``platform_post_id`` is set. On failure ``error`` holds the
    publisher's message.
    """

    success: bool
    platform_post_id: Optional[str] = None
    error: Optional[str] = None
@dataclass
class GeneratePostInput:
    """Arguments for the ``generate_post_content`` activity."""

    trend: TrendData
    agent_version_id: int  # which agent version generates the text
@dataclass
class GenerateImagesInput:
    """Arguments for the ``generate_images`` activity."""

    topic: str
    count: int = 1
    # Optional LLM-written image prompt; when empty the activity builds one from ``topic``.
    image_prompt: Optional[str] = None
@dataclass
class ModerateContentInput:
    """Arguments for the ``moderate_content`` activity."""

    post: PostData
    images: Dict[str, str]  # image key (e.g. "image_1") -> URL
@dataclass
class StoreMediaInput:
    """Arguments for the ``store_media_to_cdn`` activity."""

    images: Dict[str, str]  # image key (e.g. "image_1") -> source URL
    post_id: int  # used to name the CDN folder
@dataclass
class SavePostDraftInput:
    """Arguments for the ``save_post_draft`` activity."""

    post: PostData
    trend_id: int
    agent_version_id: int
    images: Dict[str, str]  # "image_1" is stored as the draft's image_url
@dataclass
class PublishToPlatformInput:
    """Arguments for the ``publish_to_platform`` activity."""

    post: PostData
    platform: str = "instagram"
@dataclass
class RecordMetricsInput:
    """Arguments for the ``record_execution_metrics`` activity."""

    post_id: int
    status: str  # new post status, e.g. "published"
    cost_usd: float  # added to the post's and today's running cost
    token_count: int  # added to the post's and today's running token count
    platform_post_id: Optional[str] = None
| # Activity definitions | |
@activity.defn
async def fetch_trend_details(trend_id: int) -> TrendData:
    """Load a ``Trend`` row from the database and return it as ``TrendData``.

    Args:
        trend_id: Primary key of the trend to load.

    Returns:
        The trend's id, topic, source, score (``0.0`` when the stored score
        is NULL/falsy) and raw payload.

    Raises:
        ApplicationError: non-retryable, when no trend with ``trend_id``
            exists. Retrying cannot make a missing row appear. The workflow
            calls this activity without a retry policy, so a plain exception
            would be retried indefinitely under Temporal's default policy.
    """
    from temporalio.exceptions import ApplicationError

    logger.info(f"Fetching trend details for trend_id={trend_id}")
    async with AsyncSessionLocal() as session:
        result = await session.execute(select(Trend).where(Trend.id == trend_id))
        trend = result.scalar_one_or_none()
        if trend is None:
            raise ApplicationError(
                f"Trend with ID {trend_id} not found", non_retryable=True
            )
        return TrendData(
            trend_id=trend.id,
            topic=trend.topic,
            source=trend.source,
            score=trend.score or 0.0,
            raw_data=trend.raw_data,
        )
@activity.defn
async def generate_post_content(input: GeneratePostInput) -> PostData:
    """Generate post text for a trend using the LLM agent.

    Args:
        input: The trend to write about and the agent version to use.

    Returns:
        ``PostData`` with ``post_id=0``, a placeholder until the draft is
        saved, and ``platform="instagram"``.

    Raises:
        Exception: the agent reported failure, or reported success without
            any ``post_content``. Both are retryable so the workflow's retry
            policy can try again. Placeholder text is never returned, because
            it would otherwise be published.
    """
    from src.agents.post_generator import generate_post_with_agent

    logger.info(f"Generating content for trend: {input.trend.topic}")
    trend_dict = {
        "trend_id": input.trend.trend_id,
        "topic": input.trend.topic,
        "source": input.trend.source,
        "score": input.trend.score,
        "raw_data": input.trend.raw_data,
    }
    result = await generate_post_with_agent(trend_dict, input.agent_version_id)
    if not result.get("success"):
        raise Exception(f"Agent failed to generate content: {result.get('error')}")

    content = result.get("post_content")
    if not isinstance(content, str) or not content.strip():
        raise Exception("Agent reported success but returned no post_content")

    return PostData(
        post_id=0,  # Placeholder until saved in DB
        content=content,
        platform="instagram",
    )
@activity.defn
async def generate_images(input: GenerateImagesInput) -> Dict[str, str]:
    """Generate one image for a post and return its URL.

    The prompt is ``input.image_prompt`` when non-empty. Otherwise it is a
    generic photorealistic description of ``input.topic``. An instruction
    to avoid text and watermarks is always appended.

    Backend:
      * If ``settings.openai_api_base`` contains ``"openrouter.ai"``, a
        Pollinations.AI URL is built (no API key needed). It is fetched once
        so the image is rendered and known to be reachable before returning.
      * Otherwise OpenAI DALL-E 3 is called through the async client.

    Only one image is produced; ``input.count`` is not used.

    Returns:
        ``{"image_1": <image URL>}``.

    Raises:
        httpx.HTTPStatusError: Pollinations answered with a 4xx/5xx status.
        Exception: no image URL was obtained, or the OpenAI call failed.
    """
    import urllib.parse

    from openai import AsyncOpenAI

    logger.info(f"Generating image for topic: {input.topic}")
    settings = get_settings()
    # `or ""` guards an unset base URL: `"..." in None` raises TypeError.
    is_openrouter = "openrouter.ai" in (settings.openai_api_base or "")
    try:
        base_prompt = input.image_prompt or (
            f"A photorealistic scene representing: {input.topic}. "
            f"Cinematic lighting, high quality photography."
        )
        prompt = f"{base_prompt} No text, no words, no letters, no typography, no watermarks, no captions."
        negative = "text, words, letters, typography, watermark, caption, title, logo, grid, mesh, overlay, graphic design"

        if is_openrouter:
            logger.info("Using Pollinations.AI (free, no key required) for image generation...")
            # safe="" so a '/' in the prompt is percent-encoded instead of
            # being read as an extra URL path segment.
            encoded_prompt = urllib.parse.quote(prompt, safe="")
            encoded_negative = urllib.parse.quote(negative, safe="")
            image_url = (
                f"https://image.pollinations.ai/prompt/{encoded_prompt}"
                f"?width=1024&height=1024&nologo=true&model=flux&negative={encoded_negative}"
            )
            # Verify the URL is reachable (this also triggers generation).
            async with httpx.AsyncClient() as http:
                response = await http.get(image_url, timeout=120, follow_redirects=True)
                if response.status_code != 200:
                    logger.error(f"Pollinations.AI Error ({response.status_code})")
                    response.raise_for_status()
            logger.info(f"Pollinations.AI image URL ready: {image_url}")
        else:
            logger.info("Using standard OpenAI DALL-E 3...")
            # Async client: a synchronous call would block the worker's event
            # loop for the whole generation.
            client = AsyncOpenAI(
                api_key=settings.openai_api_key,
                base_url=settings.openai_api_base or None,
            )
            try:
                response = await client.images.generate(
                    model="dall-e-3",
                    prompt=prompt,
                    size="1024x1024",
                    quality="standard",
                    n=1,
                )
            finally:
                await client.close()
            image_url = response.data[0].url

        if not image_url:
            raise Exception("Failed to extract image URL from AI response")
        logger.info(f"Successfully generated AI image: {image_url}")
        return {"image_1": image_url}
    except Exception as e:
        logger.error(f"IMAGE GENERATION CRITICAL FAILURE: {str(e)}")
        raise
@activity.defn
async def moderate_content(input: ModerateContentInput) -> Dict[str, Any]:
    """Check post text for moderation issues.

    The OpenAI Moderation endpoint is used when ``OPENAI_API_KEY`` is set to
    a non-placeholder value. If the key is missing or a placeholder, or the
    API call fails, a small keyword block-list is used instead.

    Only ``input.post.content`` is checked; ``input.images`` is not inspected.

    Returns:
        A dict with ``approved`` (bool), ``nsfw_scores`` and
        ``moderation_notes``. ``nsfw_scores`` holds the API's ``categories``
        field, or ``{}`` for the keyword fallback.
    """
    logger.info(f"Running moderation checks on post {input.post.post_id}")
    api_key = os.getenv("OPENAI_API_KEY")
    if api_key and not api_key.startswith("sk-xxx"):
        try:
            async with httpx.AsyncClient() as client:
                response = await client.post(
                    "https://api.openai.com/v1/moderations",
                    headers={"Authorization": f"Bearer {api_key}"},
                    json={"input": input.post.content},
                )
                response.raise_for_status()
                first = response.json()["results"][0]
                return {
                    "approved": not first["flagged"],
                    # NOTE(review): "categories" are per-category booleans;
                    # numeric scores live in "category_scores".
                    "nsfw_scores": first["categories"],
                    "moderation_notes": "OpenAI automated check",
                }
        except Exception as e:
            # Best-effort: an API outage or bad key falls through to the
            # keyword check below instead of failing the activity.
            logger.error(f"Moderation API failed: {e}")

    content_lower = input.post.content.lower()
    for word in ("nsfw", "violence", "hate"):
        # Leading word boundary: matches "hate", "hated", "hateful" but not
        # words that merely contain the letters (e.g. "whatever").
        if re.search(rf"\b{re.escape(word)}", content_lower):
            return {
                "approved": False,
                "nsfw_scores": {},
                "moderation_notes": f"Flagged by keyword: {word}",
            }
    return {
        "approved": True,
        "nsfw_scores": {},
        "moderation_notes": "Simple keyword check (Fallback)",
    }
@activity.defn
async def store_media_to_cdn(input: StoreMediaInput) -> Dict[str, str]:
    """Upload images to Cloudinary and return their CDN URLs.

    Each image is uploaded into the ``ai_media_os/post_<post_id>`` folder.
    The step is best-effort:
      * URLs containing ``example.com`` (dummy data) are passed through
        without uploading;
      * an upload failure, or a response without ``secure_url``, falls back
        to the original URL.

    Args:
        input: Image key -> source URL mapping, plus the post id used to
            name the Cloudinary folder.

    Returns:
        A dict with the same keys as ``input.images``, mapped to CDN URLs
        where the upload succeeded.
    """
    import asyncio
    import functools

    logger.info(f"Uploading media for post {input.post_id} to Cloudinary")
    settings = get_settings()
    cloudinary.config(
        cloud_name=settings.cloudinary_cloud_name,
        api_key=settings.cloudinary_api_key,
        api_secret=settings.cloudinary_api_secret,
        secure=True,
    )
    loop = asyncio.get_running_loop()
    cdn_urls: Dict[str, str] = {}
    for key, url in input.images.items():
        try:
            logger.info(f"Uploading {url} to Cloudinary...")
            if "example.com" in url:
                logger.warning(f"Skipping dummy URL upload: {url}")
                cdn_urls[key] = url
                continue
            # The Cloudinary SDK is synchronous; run it in the default thread
            # pool so the upload does not block the worker's event loop.
            response = await loop.run_in_executor(
                None,
                functools.partial(
                    cloudinary.uploader.upload,
                    url,
                    folder=f"ai_media_os/post_{input.post_id}",
                    resource_type="auto",
                ),
            )
            secure_url = response.get("secure_url")
            if not secure_url:
                # Keep a usable URL rather than storing None.
                logger.warning(f"Cloudinary returned no secure_url for {url}; keeping original")
                cdn_urls[key] = url
                continue
            cdn_urls[key] = secure_url
            logger.info(f"Successfully uploaded to Cloudinary: {secure_url}")
        except Exception as e:
            logger.error(f"Failed to upload to Cloudinary: {e}")
            cdn_urls[key] = url  # Fallback to original URL
    return cdn_urls
@activity.defn
async def save_post_draft(input: SavePostDraftInput) -> PostData:
    """Persist the generated post as a draft awaiting approval.

    The stored content is Markdown-stripped with ``_strip_markdown``. The
    image URL is ``input.images["image_1"]``, or ``None`` if absent. The
    row is created with ``status="draft"`` and ``approval_status="pending"``.

    NOTE(review): not idempotent. If the commit succeeds but the activity is
    retried (e.g. after a timeout), a second draft row is created.

    Returns:
        ``PostData`` with the database-assigned ``post_id`` and the values
        as stored.
    """
    logger.info(f"Saving post draft for trend {input.trend_id}")
    async with AsyncSessionLocal() as session:
        new_post = PostModel(
            trend_id=input.trend_id,
            agent_version_id=input.agent_version_id,
            content=_strip_markdown(input.post.content),
            image_url=input.images.get("image_1"),
            platform=input.post.platform,
            status="draft",
            approval_status="pending",
        )
        session.add(new_post)
        await session.commit()
        # Reload the row so database-generated values (e.g. the id) are available.
        await session.refresh(new_post)
        return PostData(
            post_id=new_post.id,
            content=new_post.content,
            image_url=new_post.image_url,
            platform=new_post.platform,
        )
@activity.defn
async def publish_to_platform(input: PublishToPlatformInput) -> PublishResult:
    """Publish a post via ``SocialMediaPublisher``.

    The caption is Markdown-stripped before sending. At most one image,
    ``input.post.image_url``, is attached.

    NOTE(review): the workflow retries this activity. If a post went out but
    the activity failed afterwards, a retry may publish it twice.

    Returns:
        On success, ``PublishResult(success=True, platform_post_id=...)``,
        where the id is ``"unknown"`` if the publisher gives none. Otherwise
        ``PublishResult(success=False, error=...)``. A publisher result
        without a ``success`` key counts as a failure rather than raising
        (and being retried).
    """
    logger.info(f"Publishing post {input.post.post_id} to {input.platform}")
    from src.services.social_media import SocialMediaPublisher

    publisher = SocialMediaPublisher()
    clean_text = _strip_markdown(input.post.content)
    result = await publisher.publish_post(
        platform=input.platform,
        text=clean_text,
        image_urls=[input.post.image_url] if input.post.image_url else [],
    )
    if result.get("success"):
        return PublishResult(
            success=True,
            platform_post_id=result.get("platform_post_id", "unknown"),
            error=None,
        )
    return PublishResult(
        success=False,
        platform_post_id=None,
        error=result.get("error", "Unknown publishing error"),
    )
@activity.defn
async def record_execution_metrics(input: RecordMetricsInput) -> None:
    """Add one run's token/cost figures to the post and to today's usage.

    * Post row, if it exists: ``token_count`` and ``cost_usd`` are added to
      its running totals and ``status`` is set. For ``"published"``,
      ``published_at`` is stamped with the current UTC time (naive datetime).
    * ``ApiUsage`` row for service ``"openai"`` and the current UTC day:
      created on first use, otherwise incremented.

    NOTE(review): not idempotent. A Temporal retry after a successful commit
    adds the figures a second time. Two concurrent first writes for a day
    may also both insert a usage row.
    """
    logger.info(
        f"Recording metrics for post {input.post_id}: "
        f"status={input.status}, cost=${input.cost_usd}, tokens={input.token_count}"
    )
    # One timestamp for both updates, so they cannot straddle midnight.
    now = datetime.utcnow()
    async with AsyncSessionLocal() as session:
        result = await session.execute(select(PostModel).where(PostModel.id == input.post_id))
        post = result.scalar_one_or_none()
        if post:
            post.token_count = (post.token_count or 0) + input.token_count
            post.generation_cost = (post.generation_cost or 0.0) + input.cost_usd
            post.status = input.status
            if input.status == "published":
                post.published_at = now

        # Track aggregate API usage per UTC day.
        period_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
        usage_result = await session.execute(
            select(ApiUsage).where(ApiUsage.service == "openai", ApiUsage.period_start == period_start)
        )
        usage = usage_result.scalar_one_or_none()
        if usage is None:
            session.add(
                ApiUsage(
                    service="openai",
                    period_start=period_start,
                    period_end=period_start.replace(hour=23, minute=59, second=59),
                    request_count=1,
                    token_count=input.token_count,
                    cost_usd=input.cost_usd,
                )
            )
        else:
            # Same NULL-tolerant accumulation as the post counters above.
            usage.request_count = (usage.request_count or 0) + 1
            usage.token_count = (usage.token_count or 0) + input.token_count
            usage.cost_usd = (usage.cost_usd or 0.0) + input.cost_usd
        await session.commit()
| # Workflow definition | |
@workflow.defn
class TrendToPostPublishWorkflow:
    """
    Main workflow: Trend → Content Generation → Moderation → Publishing.

    Flow:
    1. Fetch trend details
    2. Generate post content (LLM)
    3. Generate images
    4. Moderate content (a rejection ends the workflow)
    5. Save draft in DB (assigns the real post_id)
    6. Upload images to CDN
    7. Wait up to 24h for moderator approval (``approve_post`` signal),
       unless ``settings.auto_approve`` is enabled
    8. Publish to platform
    9. Record metrics

    Outcomes are reported in the returned dict's ``status``
    (``success``, ``rejected``, ``rejected_by_moderator``,
    ``publish_failed`` or ``failed``) rather than by failing the workflow.

    NOTE(review): this module imports I/O libraries (sqlalchemy, httpx,
    cloudinary, openai) at top level. Temporal's workflow sandbox usually
    needs such imports passed through; verify the worker configuration.
    """

    def __init__(self) -> None:
        # None = no decision yet; set by approve_post or by auto-approve.
        self._is_approved: Optional[bool] = None

    @workflow.signal
    def approve_post(self, approved: bool) -> None:
        """Record a moderator's decision (``True`` approves, ``False`` rejects)."""
        self._is_approved = approved

    @workflow.run
    async def run(
        self,
        trend_id: int,
        agent_version_id: int,
        platform: str = "instagram",
    ) -> Dict[str, Any]:
        """Turn one trend into a published post.

        Args:
            trend_id: Id of the trend to write about.
            agent_version_id: Agent version used for generation and stored
                on the draft.
            platform: Target platform for publishing.

        Returns:
            A dict with ``status`` and, depending on the outcome,
            ``post_id``, ``platform_post_id``, ``reason`` or ``error``.
        """
        import asyncio

        # workflow.logger suppresses duplicate output during replay and adds
        # workflow context; a plain module logger would log on every replay.
        log = workflow.logger
        log.info(f"Starting workflow for trend_id={trend_id}")
        try:
            # Step 1: Fetch trend
            trend = await workflow.execute_activity(
                fetch_trend_details,
                trend_id,
                start_to_close_timeout=timedelta(seconds=30),
            )

            # Step 2: Generate content, retrying with exponential backoff
            post = await workflow.execute_activity(
                generate_post_content,
                GeneratePostInput(trend=trend, agent_version_id=agent_version_id),
                start_to_close_timeout=timedelta(minutes=5),
                retry_policy=RetryPolicy(
                    maximum_attempts=5,
                    initial_interval=timedelta(seconds=2),
                    backoff_coefficient=2.0,
                ),
            )

            # Step 3: Generate images. `post` is a PostData, so its
            # image_prompt is used only if that dataclass carries such a field.
            images = await workflow.execute_activity(
                generate_images,
                GenerateImagesInput(
                    topic=trend.topic,
                    count=1,
                    image_prompt=getattr(post, "image_prompt", None),
                ),
                start_to_close_timeout=timedelta(minutes=10),
                retry_policy=RetryPolicy(
                    maximum_attempts=2,
                    initial_interval=timedelta(seconds=5),
                ),
            )

            # Step 4: Moderation
            moderation = await workflow.execute_activity(
                moderate_content,
                ModerateContentInput(post=post, images=images),
                start_to_close_timeout=timedelta(seconds=60),
            )
            if not moderation.get("approved", False):
                log.warning(f"Content moderation rejected for trend {trend_id}")
                return {
                    "status": "rejected",
                    "reason": "Moderation failed",
                    "post_id": None,
                }

            # Step 5: Save draft first (so we get a real post_id)
            post = await workflow.execute_activity(
                save_post_draft,
                SavePostDraftInput(
                    post=post,
                    trend_id=trend_id,
                    agent_version_id=agent_version_id,
                    images=images,
                ),
                start_to_close_timeout=timedelta(seconds=30),
            )

            # Step 6: Upload to CDN (now we have a real post_id)
            cdn_images = await workflow.execute_activity(
                store_media_to_cdn,
                StoreMediaInput(images=images, post_id=post.post_id),
                start_to_close_timeout=timedelta(seconds=60),
            )
            # Publish with the CDN URL.
            # NOTE(review): the saved draft row still holds the pre-CDN
            # image_url; nothing here writes the CDN URL back to the database.
            post = PostData(
                post_id=post.post_id,
                content=post.content,
                image_url=cdn_images.get("image_1", post.image_url),
                platform=post.platform,
            )

            # Step 7: Approval — auto or human-in-the-loop.
            # NOTE(review): reading settings inside workflow code is
            # non-deterministic if the value can change between the original
            # run and a replay; consider passing it in or reading it in an activity.
            settings = get_settings()
            if settings.auto_approve:
                log.info(f"AUTO_APPROVE enabled — skipping human review for post {post.post_id}")
                self._is_approved = True
            else:
                log.info(f"Waiting for moderator approval of post {post.post_id}")
                try:
                    await workflow.wait_condition(
                        lambda: self._is_approved is not None,
                        timeout=timedelta(hours=24),
                    )
                except asyncio.TimeoutError:
                    # Report the timeout explicitly; str(TimeoutError()) is empty.
                    log.warning(f"No moderator decision for post {post.post_id} within 24h")
                    return {
                        "status": "failed",
                        "post_id": post.post_id,
                        "error": "Timed out waiting for moderator approval",
                    }

            if not self._is_approved:
                log.info(f"Post {post.post_id} was rejected by moderator")
                return {
                    "status": "rejected_by_moderator",
                    "post_id": post.post_id,
                }

            # Step 8: Publish
            result = await workflow.execute_activity(
                publish_to_platform,
                PublishToPlatformInput(post=post, platform=platform),
                start_to_close_timeout=timedelta(minutes=2),
                retry_policy=RetryPolicy(
                    maximum_attempts=3,
                    initial_interval=timedelta(seconds=5),
                ),
            )
            if not result.success:
                log.error(f"Publishing failed for post {post.post_id}: {result.error}")
                return {
                    "status": "publish_failed",
                    "post_id": post.post_id,
                    "error": result.error,
                }

            # Step 9: Record metrics.
            # NOTE(review): cost_usd/token_count are fixed placeholders, not
            # measured values.
            await workflow.execute_activity(
                record_execution_metrics,
                RecordMetricsInput(
                    post_id=post.post_id,
                    status="published",
                    platform_post_id=result.platform_post_id,
                    cost_usd=0.50,
                    token_count=250,
                ),
                start_to_close_timeout=timedelta(seconds=30),
            )

            log.info(f"Workflow completed successfully for post {post.post_id}")
            return {
                "status": "success",
                "post_id": post.post_id,
                "platform_post_id": result.platform_post_id,
            }
        except Exception as e:
            # Report any failure as a "failed" result instead of failing the
            # workflow execution itself.
            log.exception(f"Workflow error: {e}")
            return {
                "status": "failed",
                "error": str(e),
            }