# AI_SNS / app.py
# Author: Corin1998
# Commit: "Create app.py" (9de15d8, verified)
import os
import asyncio
from fastapi import FastAPI, HTTPException, Body
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import gradio as gr
from db import init_db, SessionLocal, Post, Keyword
from schemas import (
SummarizeTrendsRequest, GenerateWeekPlanRequest, PostOut,
KeywordUpsertRequest, CalendarOut
)
from services.llm import summarize_trends_llm, generate_week_plan_llm
from services.trend_monitor import fetch_trend_samples
from services.scheduler import scheduler, schedule_post_job, ensure_scheduler_started
from ui import build_ui
# --- FastAPI basic setup ---
app = FastAPI(title="SNS運用AIライト", version="0.1.0")

# Allowed CORS origins, comma-separated, e.g.
# CORS_ALLOW_ORIGINS="https://a.example,https://b.example".
# Defaults to "*" (any origin).
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]
# Credentialed CORS with a wildcard origin is not allowed by the CORS spec
# (and FastAPI's docs). Starlette's CORSMiddleware works around it by echoing
# the caller's Origin on cookie-bearing requests, which would let any site
# make authenticated requests. Only allow credentials for explicit origins.
_allow_credentials = "*" not in _cors_origins
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=_allow_credentials,
    allow_methods=["*"],
    allow_headers=["*"],
)

# --- DB initialization & scheduler ---
# These run at import time: any process that imports this module initializes
# the database and starts the scheduler.
init_db()
ensure_scheduler_started(scheduler)
# --- API: trend summary ---
@app.post("/api/summarize_trends")
async def summarize_trends(req: SummarizeTrendsRequest):
    """Collect trend samples for the requested keywords and summarize them.

    Returns a dict with the raw sample ``items`` and the LLM ``summary``.
    """
    # Step 1: keyword monitoring (pseudo/sample data or a platform adapter).
    samples = await fetch_trend_samples(req.platforms, req.keywords, req.limit)
    # Step 2: have the LLM summarize the collected samples.
    summary_text = await summarize_trends_llm(samples, req.brand, req.language)
    return dict(items=samples, summary=summary_text)
# --- API: generate one week of post drafts ---
@app.post("/api/generate_week_plan", response_model=list[PostOut])
async def generate_week_plan(req: GenerateWeekPlanRequest):
    """Generate a one-week post plan with the LLM and save each post as a draft.

    Every generated post is inserted with status ``"draft"``. The saved posts
    (including their new database ids) are returned as ``PostOut`` models.
    """
    drafts = await generate_week_plan_llm(
        brand=req.brand,
        language=req.language,
        platforms=req.platforms,
        keywords=req.keywords,
        start_date=req.start_date,
        tone=req.tone,
        cta=req.cta,
        image_style_hint=req.image_style_hint,
    )

    # Persist to the DB; the session is closed when the block exits.
    with SessionLocal() as session:
        saved = []
        for draft in drafts:
            row = Post(
                platform=draft.platform,
                scheduled_at=draft.scheduled_at,
                text=draft.text,
                image_prompt=draft.image_prompt,
                status="draft",
            )
            session.add(row)
            # Flush so the database assigns row.id before it is read below.
            session.flush()
            payload = {
                "id": row.id,
                "platform": row.platform,
                "scheduled_at": row.scheduled_at,
                "text": row.text,
                "image_prompt": row.image_prompt,
                "status": row.status,
            }
            saved.append(PostOut.model_validate(payload))
        session.commit()
    return saved
# --- API: approve a post ---
@app.post("/api/approve_post/{post_id}", response_model=PostOut)
async def approve_post(post_id: int):
    """Set the post ``post_id`` to status ``"approved"`` and return it.

    Raises:
        HTTPException: 404 if no post with that id exists.
    """
    session = SessionLocal()
    try:
        post = session.get(Post, post_id)
        if not post:
            raise HTTPException(404, "post not found")
        # NOTE(review): the current status is not checked, so approving a post
        # that is already "scheduled" resets it to "approved" while any
        # registered job remains — confirm this is intended.
        post.status = "approved"
        session.commit()
        # ``from_orm`` is deprecated in Pydantic v2 and raises unless the model
        # config sets ``from_attributes=True``; request attribute reading
        # explicitly instead.
        return PostOut.model_validate(post, from_attributes=True)
    finally:
        session.close()
# --- API: schedule a post (approved posts only) ---
# Request body for POST /api/schedule_post/{post_id}.
class ScheduleIn(BaseModel):
    scheduled_at: str  # ISO-format datetime string; typed as plain str, so the format is not validated
@app.post("/api/schedule_post/{post_id}", response_model=PostOut)
async def schedule_post(post_id: int, body: ScheduleIn):
    """Schedule an approved post and register its publishing job.

    Sets ``scheduled_at`` from the request body and moves the post to
    ``"scheduled"``. It then commits and registers the job via
    ``schedule_post_job``.

    Raises:
        HTTPException: 404 if the post does not exist, 400 if its status is
            not ``"approved"``.
    """
    session = SessionLocal()
    try:
        post = session.get(Post, post_id)
        if not post:
            raise HTTPException(404, "post not found")
        if post.status != "approved":
            raise HTTPException(400, "post must be approved before scheduling")
        # NOTE(review): body.scheduled_at is the raw ISO string from the
        # request. If Post.scheduled_at is a DateTime column it may need
        # datetime.fromisoformat() first — confirm against db.Post.
        post.scheduled_at = body.scheduled_at
        post.status = "scheduled"
        session.commit()
        # Register the job with APScheduler. This happens after the commit,
        # so if schedule_post_job raises, the row stays "scheduled" without
        # a job.
        schedule_post_job(post.id, post.platform, post.text, post.image_prompt, post.scheduled_at)
        # from_orm is deprecated in Pydantic v2 and raises unless the model
        # sets from_attributes=True, so request attribute reading explicitly.
        return PostOut.model_validate(post, from_attributes=True)
    finally:
        session.close()
# --- API: calendar ---
@app.get("/api/calendar", response_model=CalendarOut)
async def calendar():
    """Return every post as a calendar event, ordered by ``scheduled_at``.

    Posts with no ``scheduled_at`` (NULL) are listed last.
    """
    session = SessionLocal()
    try:
        posts = session.query(Post).order_by(Post.scheduled_at.asc().nulls_last()).all()
        # ``from_orm`` is deprecated in Pydantic v2 and raises unless the model
        # config sets ``from_attributes=True``; request attribute reading
        # explicitly instead.
        events = [PostOut.model_validate(p, from_attributes=True) for p in posts]
        return CalendarOut(events=events)
    finally:
        session.close()
# --- API: save / list keywords ---
@app.post("/api/keywords")
async def upsert_keywords(req: KeywordUpsertRequest):
    """Replace the stored keyword list with ``req.keywords``.

    Kept deliberately simple: all existing keyword rows are deleted and the
    submitted keywords are inserted fresh, committed together.
    """
    with SessionLocal() as session:
        session.query(Keyword).delete()
        session.add_all([Keyword(text=word) for word in req.keywords])
        session.commit()
        return {"ok": True, "count": len(req.keywords)}
@app.get("/api/keywords")
async def list_keywords():
    """Return every stored keyword's text as ``{"keywords": [...]}``."""
    with SessionLocal() as session:
        rows = session.query(Keyword).all()
        texts = [row.text for row in rows]
    return {"keywords": texts}
# --- Mount the Gradio UI onto FastAPI ---
demo = build_ui(app)
# Mounted at "/" after the /api routes above were registered; Starlette
# matches routes in registration order, so the API routes are tried first.
app = gr.mount_gradio_app(app, demo, path="/")

# For local runs with Uvicorn (not needed on HF Spaces).
if __name__ == "__main__":
    import uvicorn

    # Pass the app object rather than the "app:app" import string. With the
    # string, running `python app.py` makes Uvicorn import this file a second
    # time as module "app". That re-runs init_db(), the scheduler start-up and
    # build_ui(), and it breaks if the file is renamed. An import string is
    # only required for reload/workers, which are not used here.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")), reload=False)