Spaces:
Running
Running
| """FastAPI application entry point.""" | |
import warnings

# Silence known-noisy warnings by message regex so they do not drown out the
# application logs. The first pattern presumably matches SQLAlchemy's pool
# garbage-collection notice — verify if the dependency changes.
warnings.filterwarnings("ignore", message=".*garbage collector.*non-checked-in connection.*")
warnings.filterwarnings("ignore", message=".*allowed_objects.*")

# LangChain is optional: only suppress its pending-deprecation warnings when
# the warning class can actually be imported.
try:
    from langchain_core._api.deprecation import LangChainPendingDeprecationWarning
except ImportError:
    pass
else:
    warnings.filterwarnings("ignore", category=LangChainPendingDeprecationWarning)
| from fastapi import Depends, FastAPI | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from datetime import datetime | |
| from src.models.schemas import HealthResponse | |
| from src.utils.health import get_system_health | |
| from src.utils.logging import export_logger | |
| from src.utils.database import init_db, close_db | |
| from src import __version__ | |
| from src.api.middleware.firebase_auth import get_current_user | |
logger = export_logger

# The FastAPI application object; middleware, routers and endpoints in this
# module are attached to it.
app = FastAPI(
    version=__version__,
    title="AI Media OS",
    description="Durable, scalable AI-driven social media platform",
)
# CORS: always allow the local dev front-ends, plus any extra origins given as
# a comma-separated list in the ALLOWED_ORIGINS environment variable (entries
# are whitespace-stripped; empty entries are dropped).
# NOTE(review): allow_credentials=True — avoid listing "*" in ALLOWED_ORIGINS.
import os

_extra_origins = list(filter(None, map(str.strip, os.getenv("ALLOWED_ORIGINS", "").split(","))))

app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:5173", "http://localhost:3000", "http://127.0.0.1:5173"] + _extra_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Strong references to fire-and-forget tasks spawned during startup. The event
# loop keeps only weak references to tasks, so an unreferenced task can be
# garbage-collected before it finishes; each task removes itself when done.
_startup_tasks: set = set()


async def startup_event():
    """Initialize the database and background services on startup.

    Steps, in order: connect the database, run migrations, install triggers,
    start the notify listener, start the scheduler, then launch the Instagram
    token-expiry check as a background task so it does not delay startup.

    NOTE(review): no registration decorator is visible in this view; presumably
    this is registered as the app's startup handler — confirm.
    """
    import asyncio

    logger.info("Starting up AI Media OS")
    await init_db()
    from src.utils.migrate import run_migrations
    await run_migrations()
    from src.utils.triggers import install_triggers
    await install_triggers()
    from src.utils.notify import start_listener
    start_listener()
    from src.services.scheduler import start_scheduler
    start_scheduler()
    # Check Instagram token expiry (notification if <= 14 days) without blocking
    # startup; keep a reference so the task cannot be collected mid-run.
    task = asyncio.create_task(_check_instagram_token())
    _startup_tasks.add(task)
    task.add_done_callback(_startup_tasks.discard)
async def shutdown_event():
    """Stop background services, then close the database, on shutdown.

    The scheduler and notify listener are stopped first, then the database is
    closed. Each later step runs in a ``finally`` so that a failure in an
    earlier one does not leave the listener running or database connections
    open; the original exception still propagates.

    NOTE(review): no registration decorator is visible in this view; presumably
    this is registered as the app's shutdown handler — confirm.
    """
    logger.info("Shutting down AI Media OS")
    from src.services.scheduler import stop_scheduler
    from src.utils.notify import stop_listener
    from src.utils.database import close_db

    try:
        try:
            stop_scheduler()
        finally:
            stop_listener()
    finally:
        await close_db()
async def _check_instagram_token():
    """Create a notification if the Instagram access token expires within 14 days.

    Best-effort startup check:

    1. Reads ``INSTAGRAM_ACCESS_TOKEN``; returns immediately if it is unset or empty.
    2. Asks the Graph API ``debug_token`` endpoint for the token's expiry.
    3. When 14 days or fewer remain, adds an ``ig_token_expiry`` notification,
       unless an unread one already exists.

    Any failure (missing dependency, network, HTTP error status, database) is
    logged at debug level and otherwise ignored, so the check never affects
    startup.
    """
    import os
    from datetime import datetime, timezone

    token = os.getenv("INSTAGRAM_ACCESS_TOKEN", "")
    if not token:
        return
    try:
        # Imported inside the try so a missing/broken dependency is logged like
        # any other failure instead of escaping the background task.
        import httpx

        async with httpx.AsyncClient(timeout=8) as client:
            resp = await client.get(
                "https://graph.facebook.com/v22.0/debug_token",
                params={"input_token": token, "access_token": token},
            )
        # Route HTTP errors to the debug log rather than silently reading the
        # error body as "token invalid".
        resp.raise_for_status()
        data = resp.json().get("data", {})
        if not data.get("is_valid"):
            return
        expires_at_ts = data.get("expires_at")
        # A missing or zero expiry leaves nothing to warn about.
        if not expires_at_ts:
            return
        expires_at = datetime.fromtimestamp(expires_at_ts, tz=timezone.utc)
        days = (expires_at - datetime.now(tz=timezone.utc)).days
        if days > 14:
            return

        from sqlalchemy import select
        from src.models.database import Notification
        from src.utils.database import AsyncSessionLocal

        async with AsyncSessionLocal() as session:
            # Only an *unread* expiry notification suppresses a new one. Matching
            # any historical row would silence warnings for every future token
            # once a single notification had ever been created.
            pending = (await session.execute(
                select(Notification)
                .where(
                    Notification.type == "ig_token_expiry",
                    Notification.is_read.is_(False),
                )
                .order_by(Notification.created_at.desc())
                .limit(1)
            )).scalar_one_or_none()
            if pending is None:
                session.add(Notification(
                    type="ig_token_expiry",
                    title="Instagram token expiring soon",
                    message=f"Your Instagram access token expires in {days} day{'s' if days != 1 else ''}. Go to Settings → Social to refresh it.",
                    is_read=False,
                ))
                await session.commit()
                logger.warning(f"[Startup] Instagram token expires in {days} days — notification created")
    except Exception as e:
        logger.debug(f"[Startup] Token check skipped: {e}")
async def health_check():
    """Aggregate component health into a ``HealthResponse``.

    The overall status is "healthy" only when every value returned by
    ``get_system_health()`` equals "healthy"; otherwise it is "degraded". The
    response also carries the app version, the database/temporal/redis
    entries, and a naive UTC timestamp.

    NOTE(review): route decorator not visible in this view — confirm path.
    """
    components = await get_system_health()
    degraded = any(state != "healthy" for state in components.values())
    overall = "degraded" if degraded else "healthy"
    return HealthResponse(
        status=overall,
        version=__version__,
        database=components["database"],
        temporal=components["temporal"],
        redis=components["redis"],
        timestamp=datetime.utcnow(),
    )
async def ping():
    """Liveness probe that touches neither the database nor auth; answers ``{"ok": True}``."""
    payload = {"ok": True}
    return payload
async def root():
    """Identify the service: name, version, a fixed "running" status, and the current naive-UTC time."""
    info = {"name": "AI Media OS", "version": __version__, "status": "running"}
    info["timestamp"] = datetime.utcnow()
    return info
# Import and include routers
from src.api.routes import trends, posts, agents, workflows, moderator, ws, music, agent_team, knowledge, planner, seed

# Auth dependency attached to every router below except ``ws``.
protected_dependencies = [Depends(get_current_user)]


def _mount(module, prefix, tag, *, protected=True):
    """Include ``module.router`` at ``prefix`` under the OpenAPI tag ``tag``.

    When ``protected`` is true, the router is guarded by ``protected_dependencies``;
    otherwise it is mounted without extra dependencies.
    """
    if protected:
        app.include_router(module.router, prefix=prefix, tags=[tag], dependencies=protected_dependencies)
    else:
        app.include_router(module.router, prefix=prefix, tags=[tag])


# Registration order below is kept stable (FastAPI matches routes in order).
_mount(trends, "/api/trends", "trends")
_mount(posts, "/api/posts", "posts")
_mount(agents, "/api/agents", "agents")
_mount(workflows, "/api/workflows", "workflows")
_mount(moderator, "/api/moderator", "moderator")
# No router-level auth here; presumably the ws routes verify tokens themselves — verify.
_mount(ws, "/api", "ws", protected=False)
_mount(music, "/api/music", "music")
_mount(agent_team, "/api/agent-team", "agent-team")
_mount(knowledge, "/api/knowledge", "knowledge")
_mount(planner, "/api/planner", "planner")
_mount(seed, "/api/seed", "seed")
# WebSocket registered last, directly on the app — bypasses all router-level dependencies
# (the get_current_user dependency used for the HTTP routers does not apply here).
from fastapi import WebSocket as _WS, Query as _Q
from src.api.routes.agent_team import standup_ws as _standup_ws


async def _agent_team_ws(websocket: _WS, token: str = _Q(default=None)) -> None:
    """Forward an agent-team stand-up WebSocket connection to ``agent_team.standup_ws``.

    ``token`` comes from the query string (default None) and is passed through
    unchanged. Since no router-level auth applies here, ``standup_ws`` presumably
    verifies it itself — confirm.

    NOTE(review): the registering decorator (presumably ``@app.websocket(...)``)
    is not visible in this view — confirm the route path.
    """
    await _standup_ws(websocket, token)


# Logged at import time; the scheduler WebSocket section that follows is
# defined after this message is emitted.
logger.info(f"AI Media OS v{__version__} initialized")
# ── Scheduler status WebSocket ──────────────────────────────────────────────
import asyncio as _asyncio
import json as _json
from fastapi import WebSocket as _WSched, WebSocketDisconnect as _WSDisco, Query as _QSched

# Sockets currently accepted by ``scheduler_ws``: each is added after accept and
# discarded when its handler exits. Nothing else in this section reads it —
# presumably reserved for broadcasting; verify before relying on it.
_sched_ws_clients: set[_WSched] = set()
async def _sched_status_payload() -> dict:
    """Build the JSON-serialisable scheduler status snapshot sent over the WebSocket.

    Keys:
      - ``scheduler_running``: ``scheduler.running``
      - ``posts_per_day``, ``auto_approve``: taken from ``get_settings()``
      - ``jobs``: id, name and ISO-8601 next run time (None when unscheduled)
        for each scheduler job
      - ``today_generated``: count of posts with ``created_at`` since today's
        00:00 (naive UTC)
      - ``today_published``: count of posts with status "published" and
        ``published_at`` since that same instant

    NOTE(review): the day boundary is a naive UTC datetime; this assumes the
    Post timestamp columns are stored as naive UTC — confirm.
    """
    from src.services.scheduler import scheduler
    from src.config import get_settings
    from src.models.database import Post
    from src.utils.database import AsyncSessionLocal
    from sqlalchemy import select, func

    settings = get_settings()
    jobs = [
        {"id": j.id, "name": j.name, "next_run": j.next_run_time.isoformat() if j.next_run_time else None}
        for j in scheduler.get_jobs()
    ]
    today_start = datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0)
    async with AsyncSessionLocal() as session:
        today_count = (await session.execute(
            select(func.count(Post.id)).where(Post.created_at >= today_start)
        )).scalar()
        published_today = (await session.execute(
            select(func.count(Post.id)).where(Post.published_at >= today_start, Post.status == "published")
        )).scalar()
    return {
        "scheduler_running": scheduler.running,
        "posts_per_day": settings.posts_per_day,
        "auto_approve": settings.auto_approve,
        "jobs": jobs,
        "today_generated": today_count,
        "today_published": published_today,
    }
async def scheduler_ws(websocket: _WSched, token: str = _QSched(default=None)):
    """Push scheduler status snapshots to an authenticated WebSocket client.

    ``token`` (query parameter) is checked with ``verify_firebase_token_string``.
    If that raises, the socket is closed with code 4001 without being accepted.
    Otherwise the socket is accepted and registered in ``_sched_ws_clients``, and
    a ``_sched_status_payload()`` snapshot is sent immediately. After that, a new
    snapshot is sent each time the client sends a text message, or after 30
    seconds of silence, whichever comes first.

    A client disconnect ends the stream quietly. Any other error is logged
    (previously it was silently swallowed), and the socket is then closed with
    1011 on a best-effort basis.

    NOTE(review): no registering decorator is visible in this view (presumably
    ``@app.websocket(...)``) — confirm the route path.
    """
    import contextlib

    from src.api.middleware.firebase_auth import verify_firebase_token_string

    try:
        await verify_firebase_token_string(token or "")
    except Exception:
        await websocket.close(code=4001)
        return

    await websocket.accept()
    _sched_ws_clients.add(websocket)
    try:
        # Push immediately on connect
        await websocket.send_text(_json.dumps(await _sched_status_payload()))
        while True:
            try:
                # Any inbound text (e.g. a client ping) triggers an early refresh.
                await _asyncio.wait_for(websocket.receive_text(), timeout=30)
            except _asyncio.TimeoutError:
                pass
            await websocket.send_text(_json.dumps(await _sched_status_payload()))
    except _WSDisco:
        # Normal termination: the client closed the connection.
        pass
    except Exception:
        logger.exception("[SchedulerWS] status stream failed")
        # The socket may already be closed; closing is best-effort.
        with contextlib.suppress(Exception):
            await websocket.close(code=1011)
    finally:
        _sched_ws_clients.discard(websocket)