Corin1998 commited on
Commit
9de15d8
·
verified ·
1 Parent(s): 0d6b9a6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +156 -0
app.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import asyncio
3
+ from fastapi import FastAPI, HTTPException, Body
4
+ from fastapi.middleware.cors import CORSMiddleware
5
+ from pydantic import BaseModel
6
+ import gradio as gr
7
+
8
+ from db import init_db, SessionLocal, Post, Keyword
9
+ from schemas import (
10
+ SummarizeTrendsRequest, GenerateWeekPlanRequest, PostOut,
11
+ KeywordUpsertRequest, CalendarOut
12
+ )
13
+ from services.llm import summarize_trends_llm, generate_week_plan_llm
14
+ from services.trend_monitor import fetch_trend_samples
15
+ from services.scheduler import scheduler, schedule_post_job, ensure_scheduler_started
16
+
17
+ from ui import build_ui
18
+
19
+ # --- FastAPI 基本設定 ---
20
+ app = FastAPI(title="SNS運用AIライト", version="0.1.0")
21
+ app.add_middleware(
22
+ CORSMiddleware,
23
+ allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"],
24
+ )
25
+
26
+ # --- DB 初期化 & スケジューラ ---
27
+ init_db()
28
+ ensure_scheduler_started(scheduler)
29
+
30
+ # --- API: トレンド要約 ---
31
+ @app.post("/api/summarize_trends")
32
+ async def summarize_trends(req: SummarizeTrendsRequest):
33
+ # 1) キーワード監視(疑似 or アダプタ)
34
+ items = await fetch_trend_samples(req.platforms, req.keywords, req.limit)
35
+ # 2) LLM 要約
36
+ summary = await summarize_trends_llm(items, req.brand, req.language)
37
+ return {"items": items, "summary": summary}
38
+
39
+ # --- API: 1週間分の投稿案生成 ---
40
+ @app.post("/api/generate_week_plan", response_model=list[PostOut])
41
+ async def generate_week_plan(req: GenerateWeekPlanRequest):
42
+ posts = await generate_week_plan_llm(
43
+ brand=req.brand,
44
+ language=req.language,
45
+ platforms=req.platforms,
46
+ keywords=req.keywords,
47
+ start_date=req.start_date,
48
+ tone=req.tone,
49
+ cta=req.cta,
50
+ image_style_hint=req.image_style_hint,
51
+ )
52
+ # DB保存
53
+ session = SessionLocal()
54
+ try:
55
+ out = []
56
+ for p in posts:
57
+ post = Post(
58
+ platform=p.platform,
59
+ scheduled_at=p.scheduled_at,
60
+ text=p.text,
61
+ image_prompt=p.image_prompt,
62
+ status="draft",
63
+ )
64
+ session.add(post)
65
+ session.flush()
66
+ out.append(PostOut.model_validate({
67
+ "id": post.id,
68
+ "platform": post.platform,
69
+ "scheduled_at": post.scheduled_at,
70
+ "text": post.text,
71
+ "image_prompt": post.image_prompt,
72
+ "status": post.status,
73
+ }))
74
+ session.commit()
75
+ return out
76
+ finally:
77
+ session.close()
78
+
79
+ # --- API: 承認 ---
80
+ @app.post("/api/approve_post/{post_id}", response_model=PostOut)
81
+ async def approve_post(post_id: int):
82
+ session = SessionLocal()
83
+ try:
84
+ post = session.get(Post, post_id)
85
+ if not post:
86
+ raise HTTPException(404, "post not found")
87
+ post.status = "approved"
88
+ session.commit()
89
+ return PostOut.from_orm(post)
90
+ finally:
91
+ session.close()
92
+
93
+ # --- API: 予約設定(承認済みのみ) ---
94
+ class ScheduleIn(BaseModel):
95
+ scheduled_at: str # ISO
96
+
97
+ @app.post("/api/schedule_post/{post_id}", response_model=PostOut)
98
+ async def schedule_post(post_id: int, body: ScheduleIn):
99
+ session = SessionLocal()
100
+ try:
101
+ post = session.get(Post, post_id)
102
+ if not post:
103
+ raise HTTPException(404, "post not found")
104
+ if post.status != "approved":
105
+ raise HTTPException(400, "post must be approved before scheduling")
106
+ post.scheduled_at = body.scheduled_at
107
+ post.status = "scheduled"
108
+ session.commit()
109
+
110
+ # APScheduler に登録
111
+ schedule_post_job(post.id, post.platform, post.text, post.image_prompt, post.scheduled_at)
112
+ return PostOut.from_orm(post)
113
+ finally:
114
+ session.close()
115
+
116
+ # --- API: カレンダー ---
117
+ @app.get("/api/calendar", response_model=CalendarOut)
118
+ async def calendar():
119
+ session = SessionLocal()
120
+ try:
121
+ posts = session.query(Post).order_by(Post.scheduled_at.asc().nulls_last()).all()
122
+ return CalendarOut(events=[PostOut.from_orm(p) for p in posts])
123
+ finally:
124
+ session.close()
125
+
126
+ # --- API: キーワード保存/取得 ---
127
+ @app.post("/api/keywords")
128
+ async def upsert_keywords(req: KeywordUpsertRequest):
129
+ session = SessionLocal()
130
+ try:
131
+ # 単純化: 既存削除→再作成
132
+ session.query(Keyword).delete()
133
+ for kw in req.keywords:
134
+ session.add(Keyword(text=kw))
135
+ session.commit()
136
+ return {"ok": True, "count": len(req.keywords)}
137
+ finally:
138
+ session.close()
139
+
140
+ @app.get("/api/keywords")
141
+ async def list_keywords():
142
+ session = SessionLocal()
143
+ try:
144
+ kws = [k.text for k in session.query(Keyword).all()]
145
+ return {"keywords": kws}
146
+ finally:
147
+ session.close()
148
+
149
+ # --- Gradio UI を FastAPI にマウント ---
150
+ demo = build_ui(app)
151
+ app = gr.mount_gradio_app(app, demo, path="/")
152
+
153
+ # Uvicornローカル実行用(HF Spacesでは不要)
154
+ if __name__ == "__main__":
155
+ import uvicorn
156
+ uvicorn.run("app:app", host="0.0.0.0", port=int(os.getenv("PORT", "7860")), reload=False)