| import os |
| import asyncio |
| from fastapi import FastAPI, HTTPException, Body |
| from fastapi.middleware.cors import CORSMiddleware |
| from pydantic import BaseModel |
| import gradio as gr |
|
|
| from db import init_db, SessionLocal, Post, Keyword |
| from schemas import ( |
| SummarizeTrendsRequest, GenerateWeekPlanRequest, PostOut, |
| KeywordUpsertRequest, CalendarOut |
| ) |
| from services.llm import summarize_trends_llm, generate_week_plan_llm |
| from services.trend_monitor import fetch_trend_samples |
| from services.scheduler import scheduler, schedule_post_job, ensure_scheduler_started |
|
|
| from ui import build_ui |
|
|
| |
app = FastAPI(title="SNS運用AIライト", version="0.1.0")

# Allowed CORS origins, read from $CORS_ALLOW_ORIGINS as a comma-separated list
# (e.g. "https://a.example,https://b.example"). An unset or blank value falls
# back to "*", which keeps the original allow-everything behaviour.
# NOTE(review): "*" together with allow_credentials=True makes Starlette
# reflect the caller's Origin on credentialed requests, i.e. any site can make
# cookie-bearing cross-origin calls. Set an explicit list in production.
_CORS_ALLOW_ORIGINS = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_CORS_ALLOW_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialise the database and make sure the job scheduler is running. This
# happens at import time, before any route or the mounted UI handles a request.
init_db()
ensure_scheduler_started(scheduler)
|
|
| |
@app.post("/api/summarize_trends")
async def summarize_trends(req: SummarizeTrendsRequest):
    """Collect trend samples and pair them with an LLM-written summary.

    Args:
        req: Platforms, keywords and sample limit for the trend fetch, plus
            the brand and language passed on to the summariser.

    Returns:
        ``{"items": <fetched samples>, "summary": <LLM summary>}``.
    """
    samples = await fetch_trend_samples(req.platforms, req.keywords, req.limit)
    return {
        "items": samples,
        "summary": await summarize_trends_llm(samples, req.brand, req.language),
    }
|
|
| |
@app.post("/api/generate_week_plan", response_model=list[PostOut])
async def generate_week_plan(req: GenerateWeekPlanRequest):
    """Generate a week of posts with the LLM and store each one as a draft.

    The LLM call finishes before a database session is opened. Every returned
    post is then inserted as a ``Post`` row with status ``"draft"``.

    Args:
        req: Brand, language, platforms, keywords, start date, tone, CTA and
            image-style hint, all forwarded to ``generate_week_plan_llm``.

    Returns:
        One ``PostOut`` per newly created row, in the order the LLM returned.
    """
    drafts = await generate_week_plan_llm(
        brand=req.brand,
        language=req.language,
        platforms=req.platforms,
        keywords=req.keywords,
        start_date=req.start_date,
        tone=req.tone,
        cta=req.cta,
        image_style_hint=req.image_style_hint,
    )

    exported_fields = ("id", "platform", "scheduled_at", "text", "image_prompt", "status")

    session = SessionLocal()
    try:
        created = []
        for draft in drafts:
            row = Post(
                platform=draft.platform,
                scheduled_at=draft.scheduled_at,
                text=draft.text,
                image_prompt=draft.image_prompt,
                status="draft",
            )
            session.add(row)
            # Flush so the database assigns row.id before it is serialised.
            session.flush()
            payload = {field: getattr(row, field) for field in exported_fields}
            created.append(PostOut.model_validate(payload))
        session.commit()
        return created
    finally:
        session.close()
|
|
| |
@app.post("/api/approve_post/{post_id}", response_model=PostOut)
async def approve_post(post_id: int):
    """Set a post's status to ``"approved"``.

    Approval is required before ``/api/schedule_post`` will accept the post.

    Args:
        post_id: Primary key of the post to approve.

    Returns:
        The updated post as ``PostOut``.

    Raises:
        HTTPException: 404 if no post has the given id.
    """
    session = SessionLocal()
    try:
        post = session.get(Post, post_id)
        if post is None:
            raise HTTPException(404, "post not found")
        post.status = "approved"
        session.commit()
        # ``from_orm`` is deprecated in Pydantic v2 (this module already uses
        # ``model_validate``). It also raises unless PostOut enables
        # ``from_attributes``, so request attribute access explicitly.
        return PostOut.model_validate(post, from_attributes=True)
    finally:
        session.close()
|
|
| |
class ScheduleIn(BaseModel):
    """Request body for ``POST /api/schedule_post/{post_id}``."""

    # New publish time as a string. The endpoint assigns it to
    # ``Post.scheduled_at`` without parsing it.
    scheduled_at: str
|
|
@app.post("/api/schedule_post/{post_id}", response_model=PostOut)
async def schedule_post(post_id: int, body: ScheduleIn):
    """Give an approved post a publish time and register its scheduler job.

    Only posts whose status is ``"approved"`` can be scheduled. The new state
    (``scheduled_at`` and status ``"scheduled"``) is committed first, and then
    the job is registered via ``schedule_post_job``.

    Args:
        post_id: Primary key of the post.
        body: Carries the new ``scheduled_at`` string.

    Returns:
        The updated post as ``PostOut``.

    Raises:
        HTTPException: 404 if the post does not exist; 400 if it is not
            approved. If registering the job raises, the post's previous
            ``status`` and ``scheduled_at`` are restored and committed before
            the original error propagates.
    """
    session = SessionLocal()
    try:
        post = session.get(Post, post_id)
        if post is None:
            raise HTTPException(404, "post not found")
        if post.status != "approved":
            raise HTTPException(400, "post must be approved before scheduling")

        previous_scheduled_at = post.scheduled_at
        post.scheduled_at = body.scheduled_at
        post.status = "scheduled"
        session.commit()

        # If job registration fails, do not leave the row marked "scheduled"
        # with nothing behind it: restore the prior values, then re-raise.
        try:
            schedule_post_job(
                post.id, post.platform, post.text, post.image_prompt, post.scheduled_at
            )
        except Exception:
            post.scheduled_at = previous_scheduled_at
            post.status = "approved"
            session.commit()
            raise

        # ``from_orm`` is deprecated in Pydantic v2 and raises unless PostOut
        # enables ``from_attributes``; request attribute access explicitly.
        return PostOut.model_validate(post, from_attributes=True)
    finally:
        session.close()
|
|
| |
@app.get("/api/calendar", response_model=CalendarOut)
async def calendar():
    """Return every post as a calendar event, ordered by ``scheduled_at``.

    Posts with no ``scheduled_at`` value are listed last.

    Returns:
        ``CalendarOut`` whose ``events`` hold one ``PostOut`` per post.
    """
    session = SessionLocal()
    try:
        posts = session.query(Post).order_by(Post.scheduled_at.asc().nulls_last()).all()
        # ``from_orm`` is deprecated in Pydantic v2 (already used here via
        # ``model_validate``) and raises unless PostOut enables
        # ``from_attributes``; pass it explicitly.
        events = [PostOut.model_validate(p, from_attributes=True) for p in posts]
        return CalendarOut(events=events)
    finally:
        session.close()
|
|
| |
@app.post("/api/keywords")
async def upsert_keywords(req: KeywordUpsertRequest):
    """Replace the stored keyword list with ``req.keywords``.

    This is a full replacement, not a merge: all existing ``Keyword`` rows are
    deleted before the submitted keywords are inserted.

    Returns:
        ``{"ok": True, "count": n}``, where ``n`` is the number of keywords
        submitted.
    """
    with SessionLocal() as session:
        session.query(Keyword).delete()
        session.add_all([Keyword(text=kw) for kw in req.keywords])
        session.commit()
        return {"ok": True, "count": len(req.keywords)}
|
|
@app.get("/api/keywords")
async def list_keywords():
    """Return the text of every stored keyword as ``{"keywords": [...]}``."""
    with SessionLocal() as session:
        rows = session.query(Keyword).all()
        return {"keywords": [row.text for row in rows]}
|
|
| |
# Build the Gradio UI (it is handed the FastAPI app) and mount it at the site
# root. The mount is added after all /api routes above, so those routes are
# matched before the catch-all "/" mount.
demo = build_ui(app)
app = gr.mount_gradio_app(app=app, blocks=demo, path="/")
|
|
| |
if __name__ == "__main__":
    import uvicorn

    # Pass the app object rather than the "app:app" import string. With reload
    # disabled the string is unnecessary, and it would make uvicorn import this
    # file a second time under the module name "app". That re-runs init_db(),
    # ensure_scheduler_started() and build_ui(), and it also breaks if the file
    # is renamed. HOST/PORT may be overridden from the environment.
    uvicorn.run(
        app,
        host=os.getenv("HOST", "0.0.0.0"),
        port=int(os.getenv("PORT", "7860")),
        reload=False,
    )
|
|