File size: 9,512 Bytes
c84fdae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f133f15
 
 
c84fdae
 
 
 
 
 
 
 
 
 
 
 
 
 
f133f15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c84fdae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97d0e8f
c84fdae
 
 
 
 
 
 
 
 
 
 
 
 
97d0e8f
c84fdae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
"""FastAPI application entry point."""

import warnings
# Silence known-noisy warnings. These filters are installed before the rest of
# the imports, presumably so warnings emitted at import time are filtered too.
# Message-pattern filter for "garbage collector ... non-checked-in connection"
# warnings (NOTE(review): presumably SQLAlchemy's pool warning about connections
# collected without being returned to the pool — confirm the source).
warnings.filterwarnings("ignore", message=".*garbage collector.*non-checked-in connection.*")
# Any warning whose message mentions "allowed_objects" (emitter not visible here).
warnings.filterwarnings("ignore", message=".*allowed_objects.*")
# LangChain pending-deprecation warnings; the import is guarded so the app still
# starts when langchain_core (or this private symbol) is unavailable.
try:
    from langchain_core._api.deprecation import LangChainPendingDeprecationWarning
    warnings.filterwarnings("ignore", category=LangChainPendingDeprecationWarning)
except ImportError:
    pass

from fastapi import Depends, FastAPI
from fastapi.middleware.cors import CORSMiddleware
from datetime import datetime
from src.models.schemas import HealthResponse
from src.utils.health import get_system_health
from src.utils.logging import export_logger
from src.utils.database import init_db, close_db
from src import __version__
from src.api.middleware.firebase_auth import get_current_user

logger = export_logger

# The application object; title/description/version populate the OpenAPI schema.
app = FastAPI(
    title="AI Media OS",
    description="Durable, scalable AI-driven social media platform",
    version=__version__,
)

# CORS: three localhost origins are always allowed. Further origins come from
# the comma-separated ALLOWED_ORIGINS environment variable; each entry is
# whitespace-trimmed and empty entries are dropped.
import os

_extra_origins = list(
    filter(None, map(str.strip, os.getenv("ALLOWED_ORIGINS", "").split(",")))
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost:5173",
        "http://localhost:3000",
        "http://127.0.0.1:5173",
    ] + _extra_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.on_event("startup")
async def startup_event():
    """Initialize database and background services on startup.

    Steps, in order: connect the database, run migrations, install database
    triggers, start the notify listener, start the scheduler, and finally kick
    off a non-blocking Instagram token-expiry check.
    """
    logger.info("Starting up AI Media OS")
    await init_db()
    from src.utils.migrate import run_migrations
    await run_migrations()
    from src.utils.triggers import install_triggers
    await install_triggers()
    from src.utils.notify import start_listener
    start_listener()
    from src.services.scheduler import start_scheduler
    start_scheduler()
    # Check Instagram token expiry and create a notification if <= 14 days remain.
    import asyncio
    # The event loop only keeps a weak reference to tasks; hold a strong one on
    # app.state so the fire-and-forget check cannot be garbage-collected mid-run.
    app.state.ig_token_check_task = asyncio.create_task(_check_instagram_token())


@app.on_event("shutdown")
async def shutdown_event():
    """Clean up on shutdown.

    Stops the scheduler, then the notify listener, then closes the database.
    Each step runs even if an earlier one raises (the first exception still
    propagates after all cleanup has been attempted), so a failing scheduler
    stop can no longer leave database connections open.
    """
    logger.info("Shutting down AI Media OS")
    from src.services.scheduler import stop_scheduler
    from src.utils.notify import stop_listener
    try:
        stop_scheduler()
    finally:
        try:
            stop_listener()
        finally:
            # close_db is imported at module level.
            await close_db()


def ig_token_days_left(expires_at_ts, now=None):
    """Return whole days from *now* until the Unix timestamp *expires_at_ts*.

    Args:
        expires_at_ts: Expiry as seconds since the epoch.
        now: Timezone-aware reference time; defaults to the current UTC time.

    Returns:
        ``(expiry - now).days``. The value is floored, so it is negative once
        the expiry has passed.
    """
    from datetime import datetime, timezone
    expires_at = datetime.fromtimestamp(expires_at_ts, tz=timezone.utc)
    if now is None:
        now = datetime.now(tz=timezone.utc)
    return (expires_at - now).days


def ig_token_expiry_message(days):
    """Build the user-facing notification text for a token expiring in *days* days."""
    plural = "s" if days != 1 else ""
    return (
        f"Your Instagram access token expires in {days} day{plural}. "
        "Go to Settings → Social to refresh it."
    )


async def _check_instagram_token(warn_days: int = 14):
    """On startup, create a notification if the Instagram token expires soon.

    Best-effort: every failure (network, JSON, database) is caught and logged
    at debug level so startup is never affected.

    Args:
        warn_days: Notify when the token has this many whole days or fewer left.
    """
    import os
    import httpx
    try:
        token = os.getenv("INSTAGRAM_ACCESS_TOKEN", "")
        if not token:
            return
        # The token is used to inspect itself (input_token == access_token).
        async with httpx.AsyncClient(timeout=8) as client:
            resp = await client.get(
                "https://graph.facebook.com/v22.0/debug_token",
                params={"input_token": token, "access_token": token},
            )
        data = resp.json().get("data", {})
        if not data.get("is_valid"):
            return
        expires_at_ts = data.get("expires_at")
        # Missing or 0 (presumably a non-expiring token): nothing to warn about.
        if not expires_at_ts:
            return
        days = ig_token_days_left(expires_at_ts)
        if days > warn_days:
            return
        from src.utils.database import AsyncSessionLocal
        from src.models.database import Notification
        from sqlalchemy import select
        async with AsyncSessionLocal() as session:
            # Only an *unread* warning suppresses a new one. Matching any past
            # notification meant the user was never warned again after the
            # first time, even for a later, refreshed token.
            existing = (await session.execute(
                select(Notification)
                .where(
                    Notification.type == "ig_token_expiry",
                    Notification.is_read.is_(False),
                )
                .order_by(Notification.created_at.desc())
                .limit(1)
            )).scalar_one_or_none()
            if existing:
                return
            session.add(Notification(
                type="ig_token_expiry",
                title="Instagram token expiring soon",
                message=ig_token_expiry_message(days),
                is_read=False,
            ))
            await session.commit()
        logger.warning(f"[Startup] Instagram token expires in {days} days — notification created")
    except Exception as e:
        logger.debug(f"[Startup] Token check skipped: {e}")


@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint.

    The overall status is "healthy" only when every component reported by
    ``get_system_health()`` is "healthy"; otherwise it is "degraded". The
    database, Temporal and Redis component states are echoed back as-is.
    """
    components = await get_system_health()
    any_unhealthy = any(state != "healthy" for state in components.values())
    return HealthResponse(
        status="degraded" if any_unhealthy else "healthy",
        version=__version__,
        database=components["database"],
        temporal=components["temporal"],
        redis=components["redis"],
        timestamp=datetime.utcnow(),
    )


@app.get("/api/ping")
async def ping():
    """Ultra-lightweight liveness check — no DB, no auth."""
    # No I/O and no dependencies, so this still answers when the database or
    # the auth provider is unavailable. (The docstring above is published as
    # the OpenAPI description; it previously contained mojibake.)
    return {"ok": True}


@app.get("/")
async def root():
    """Root endpoint.

    Returns a small identification payload: service name, version, a fixed
    "running" status and the current server time from ``datetime.utcnow()``.
    """
    payload = dict(name="AI Media OS", version=__version__, status="running")
    payload["timestamp"] = datetime.utcnow()
    return payload


# Import and include routers
from src.api.routes import trends, posts, agents, workflows, moderator, ws, music, agent_team, knowledge, planner, seed

# Router-level auth: every router below except `ws` requires a Firebase user.
protected_dependencies = [Depends(get_current_user)]

# (router module, URL prefix, OpenAPI tag, requires auth), in registration order.
_router_table = (
    (trends, "/api/trends", "trends", True),
    (posts, "/api/posts", "posts", True),
    (agents, "/api/agents", "agents", True),
    (workflows, "/api/workflows", "workflows", True),
    (moderator, "/api/moderator", "moderator", True),
    (ws, "/api", "ws", False),
    (music, "/api/music", "music", True),
    (agent_team, "/api/agent-team", "agent-team", True),
    (knowledge, "/api/knowledge", "knowledge", True),
    (planner, "/api/planner", "planner", True),
    (seed, "/api/seed", "seed", True),
)
for _mod, _prefix, _tag, _guarded in _router_table:
    # dependencies=None is include_router's default, i.e. "no extra dependencies".
    app.include_router(
        _mod.router,
        prefix=_prefix,
        tags=[_tag],
        dependencies=protected_dependencies if _guarded else None,
    )
# Keep the module namespace as it was: only protected_dependencies remains.
del _router_table, _mod, _prefix, _tag, _guarded

# WebSocket registered last, directly on the app — bypasses all router-level
# dependencies (the Firebase `get_current_user` dependency attached to the
# routers above). Authentication is delegated to `standup_ws`, which receives
# the `token` query parameter (presumably a Firebase ID token — confirm there).
# NOTE(review): if agent_team.router itself declares a "/ws" websocket route, that
# copy (with the auth dependency) was registered earlier and Starlette matches
# routes in registration order — confirm this handler is actually reachable.
from fastapi import WebSocket as _WS, Query as _Q
from src.api.routes.agent_team import standup_ws as _standup_ws
@app.websocket("/api/agent-team/ws")
async def _agent_team_ws(websocket: _WS, token: str = _Q(default=None)):
    """Forward the agent-team stand-up WebSocket to `standup_ws` with its token."""
    await _standup_ws(websocket, token)

# Emitted once at import time, after all routers have been registered.
logger.info(f"AI Media OS v{__version__} initialized")


# ── Scheduler status WebSocket ────────────────────────────────────────────────
import asyncio as _asyncio
import json as _json
from fastapi import WebSocket as _WSched, WebSocketDisconnect as _WSDisco, Query as _QSched

# Sockets currently connected to /api/scheduler/ws: `scheduler_ws` adds a socket
# after accept() and discards it when the connection ends. Nothing in the visible
# code iterates this set (NOTE(review): presumably reserved for broadcast pushes).
_sched_ws_clients: set[_WSched] = set()


async def _sched_status_payload() -> dict:
    """Build the JSON-serialisable status snapshot pushed over /api/scheduler/ws.

    Returns:
        A dict with:
        - ``scheduler_running``: the scheduler's ``running`` flag.
        - ``posts_per_day`` and ``auto_approve``: copied from settings.
        - ``jobs``: one entry per scheduled job (id, name, ISO-8601 next run
          time or None).
        - ``today_generated``: posts created since midnight today.
        - ``today_published``: posts with status "published" whose
          ``published_at`` is since midnight today.
    """
    from src.services.scheduler import scheduler
    from src.config import get_settings
    from src.models.database import Post
    from src.utils.database import AsyncSessionLocal
    from sqlalchemy import select, func

    settings = get_settings()
    jobs = [
        {"id": j.id, "name": j.name, "next_run": j.next_run_time.isoformat() if j.next_run_time else None}
        for j in scheduler.get_jobs()
    ]
    # Midnight of the current day per datetime.utcnow() (a naive UTC timestamp).
    # NOTE(review): assumes Post.created_at / published_at are stored as naive UTC.
    today_start = datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0)
    async with AsyncSessionLocal() as session:
        today_count = (await session.execute(
            select(func.count(Post.id)).where(Post.created_at >= today_start)
        )).scalar()
        published_today = (await session.execute(
            select(func.count(Post.id)).where(Post.published_at >= today_start, Post.status == "published")
        )).scalar()
    return {
        "scheduler_running": scheduler.running,
        "posts_per_day": settings.posts_per_day,
        "auto_approve": settings.auto_approve,
        "jobs": jobs,
        "today_generated": today_count,
        "today_published": published_today,
    }


@app.websocket("/api/scheduler/ws")
async def scheduler_ws(websocket: _WSched, token: str = _QSched(default=None)):
    """Stream scheduler status to an authenticated client.

    The Firebase token is read from the ``token`` query parameter. If
    verification fails, the socket is closed with code 4001 before it is
    accepted. Once accepted, a status payload is sent immediately, then again
    after each message from the client or after 30 s without one, until the
    client disconnects. Unexpected errors end the stream and are logged
    instead of being silently swallowed.
    """
    from src.api.middleware.firebase_auth import verify_firebase_token_string
    try:
        await verify_firebase_token_string(token or "")
    except Exception:
        # NOTE(review): closing before accept() rejects the handshake; the client
        # may see an HTTP 403 rather than close code 4001 — confirm the frontend
        # handles both.
        await websocket.close(code=4001)
        return

    await websocket.accept()
    _sched_ws_clients.add(websocket)
    try:
        # Push immediately on connect
        await websocket.send_text(_json.dumps(await _sched_status_payload()))
        while True:
            # Push every 30s; any client message (e.g. a ping) triggers a push too
            try:
                await _asyncio.wait_for(websocket.receive_text(), timeout=30)
            except _asyncio.TimeoutError:
                pass
            except _WSDisco:
                break
            await websocket.send_text(_json.dumps(await _sched_status_payload()))
    except _WSDisco:
        # Client went away while a payload was being sent: normal termination.
        pass
    except Exception as e:
        # Previously swallowed silently, which hid DB/payload failures.
        logger.warning(f"[SchedulerWS] stream closed after error: {e}")
    finally:
        _sched_ws_clients.discard(websocket)