Corin1998's picture
Update app/main.py
00cc288 verified
from __future__ import annotations
from typing import List, Optional
from fastapi import FastAPI
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import ORJSONResponse
import gradio as gr
import pandas as pd
from .config import settings
from .lib.topic import TopicEngine
# FastAPI application; every JSON endpoint is serialised with ORJSONResponse
# unless a route overrides it.
app = FastAPI(title="SNS Analyzer", default_response_class=ORJSONResponse)

# Fully permissive CORS: any origin, any HTTP method, any request header.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# ---- Engine (refers to embedding_model when the settings are loaded) ----
# One shared instance, built at import time, is used by both the REST API
# and the Gradio UI below.
_engine = TopicEngine()
# ---------- REST API ----------
class AnalyzeRequest(BaseModel):
    """JSON body accepted by ``POST /api/analyze``."""

    # The documents to analyse, one string per entry.
    texts: List[str]
    # Requested number of clusters; None asks for automatic selection
    # (the Gradio UI maps its "0 = auto" slider setting to None).
    n_clusters: Optional[int] = None
class AnalyzeResponse(BaseModel):
    """JSON body returned by ``POST /api/analyze``; mirrors the engine result."""

    # Summary text produced by the engine.
    summary: str
    # Per-text rows; the UI plots them using the x, y, cluster and text keys.
    points: List[dict]
    # Per-cluster rows; the UI tabulates them as id, size and top_terms.
    topics: List[dict]
@app.get("/health")
def health():
    """Liveness probe: a constant payload showing the FastAPI app responds."""
    payload = dict(status="ok", message="FastAPI up")
    return payload
@app.post("/api/analyze", response_model=AnalyzeResponse)
def api_analyze(req: AnalyzeRequest):
    """Cluster and summarise ``req.texts`` with the shared topic engine.

    ``n_clusters`` of ``None``, ``0`` or any negative value is treated as
    "choose automatically" and forwarded to the engine as ``None``. This
    matches the Gradio UI, where a slider value of 0 means automatic.
    Previously a 0 was passed through verbatim.

    Returns:
        AnalyzeResponse built from the engine result's ``summary``,
        ``points`` and ``topics``.
    """
    n_clusters = req.n_clusters if req.n_clusters and req.n_clusters > 0 else None
    res = _engine.analyze(req.texts, n_clusters)
    return AnalyzeResponse(summary=res.summary, points=res.points, topics=res.topics)
# ---------- Gradio UI ----------
def _run_ui(text_block: str, k: int | None):
lines = [ln.strip() for ln in (text_block or "").splitlines() if ln.strip()]
res = _engine.analyze(lines, k if k and k > 0 else None)
df_points = pd.DataFrame(res.points) if res.points else pd.DataFrame(columns=["x", "y", "cluster", "text"])
df_topics = pd.DataFrame(res.topics) if res.topics else pd.DataFrame(columns=["id", "size", "top_terms"])
# ScatterPlot は DataFrame をそのまま渡せます
return res.summary, df_points, df_topics
with gr.Blocks(fill_height=True) as demo:
    gr.Markdown("## 🔎 SNS Analyzer (クラスタリング + 要約) — OpenAI API を使用")

    # Input row: paste box on the left, controls stacked in a column on the right.
    with gr.Row():
        inp = gr.Textbox(
            lines=10,
            label="テキスト(1行1件で貼り付け)",
            placeholder="例:\n新商品の発売が楽しみ...\n価格が上がっているのが気になる...\nUIがすごく改善されたと思う...\n...",
        )
        with gr.Column():
            # 0 selects automatic cluster count (see _run_ui).
            k = gr.Slider(0, 10, value=0, step=1, label="クラスタ数(0=自動)")
            btn = gr.Button("分析する", variant="primary")

    # Outputs, declared in the same order as _run_ui's return tuple.
    summary = gr.Markdown(label="要約")
    scatter = gr.ScatterPlot(
        x="x",
        y="y",
        color="cluster",
        tooltip="text",
        label="2D配置(UMAP)",
        height=420,
    )
    topics = gr.Dataframe(
        headers=["id", "size", "top_terms"],
        label="クラスタ語(上位TF-IDF)",
        interactive=False,
    )

    btn.click(fn=_run_ui, inputs=[inp, k], outputs=[summary, scatter, topics])
# Mount the Gradio UI at the site root. FastAPI keeps serving /docs and the
# routes registered on `app` earlier in this module.
app = gr.mount_gradio_app(app=app, blocks=demo, path="/")