PR_IRminiSaaS / app.py
Corin1998's picture
Update app.py
d2dfe17 verified
import os
import gradio as gr
from typing import List, Optional
from sqlalchemy.orm import Session
from db import init_db, SessionLocal, Draft, DraftStatus, Tone, ContentType, Delivery
from ingest_utils import extract_from_pdf, extract_from_url
from llm_utils import generate_draft
from emailer import send_email
from x_client import post_to_x
from note_client import post_to_note
init_db()  # Initialise the database at import time; per the original note the SQLite file is placed under /data or /tmp automatically (logic lives in db.init_db, not visible here).
def _session() -> Session:
    """Open a fresh ORM session from the shared ``SessionLocal`` factory.

    Every call returns a new session; callers in this module close it
    themselves when they are done.
    """
    session = SessionLocal()
    return session
# ---------------------------
# Helpers
# ---------------------------
def _parse_urls(text: str) -> List[str]:
if not text:
return []
lines = [ln.strip() for ln in text.splitlines()]
return [ln for ln in lines if ln and not ln.startswith("#")]
def _concat_chunks(chunks: List[str], max_chars: int = 20000) -> str:
"""LLMに渡す素材を結合。長すぎる場合は先頭から安全にトリム。"""
joined = "\n\n---\n".join([c for c in chunks if c])
if len(joined) <= max_chars:
return joined
return joined[:max_chars]
# ---------------------------
# Core actions
# ---------------------------
def _pdf_label(item) -> str:
    """Short display name (file name only) for one uploaded PDF item."""
    raw = item if isinstance(item, (str, os.PathLike)) else getattr(item, "name", None)
    if isinstance(raw, (str, os.PathLike)):
        name = os.fspath(raw)
        if isinstance(name, str):
            return os.path.basename(name) or "uploaded.pdf"
    return "uploaded.pdf"


def _read_pdf_bytes(item) -> bytes:
    """Return the raw bytes of one ``gr.File`` item.

    Accepts raw ``bytes``, a filesystem path (``str``/``os.PathLike``; Gradio 4's
    ``gr.File`` returns paths by default), or a file-like object exposing
    ``name`` and/or ``read()`` (older Gradio temp-file wrappers).

    Raises:
        OSError: the path cannot be opened.
        TypeError: the item is none of the supported kinds.
    """
    if isinstance(item, (bytes, bytearray)):
        return bytes(item)
    if isinstance(item, (str, os.PathLike)):
        with open(item, "rb") as fp:
            return fp.read()
    # Prefer re-opening by path: a temp-file wrapper may already be closed or at EOF.
    name = getattr(item, "name", None)
    if isinstance(name, (str, os.PathLike)) and os.path.isfile(name):
        with open(name, "rb") as fp:
            return fp.read()
    if hasattr(item, "read"):
        return item.read()
    raise TypeError(f"unsupported upload type: {type(item).__name__}")


def _collect_sources(use_urls, urls_multi, use_pdfs, pdf_files, use_text, text):
    """Extract every enabled source.

    Returns:
        ``(chunks, refs, errors)``:
        - ``chunks``: the successful extracts.
        - ``refs``: one reference per extract (URL, PDF file name or ``"text"``).
        - ``errors``: one note per source that failed.
    """
    chunks: List[str] = []
    refs: List[str] = []
    errors: List[str] = []

    if use_urls:
        for url in _parse_urls(urls_multi):
            try:
                chunks.append(extract_from_url(url))
            except Exception as e:  # a bad URL must not abort the other sources
                errors.append(f"[URL取得エラー: {url} / {e}]")
            else:
                refs.append(url)

    if use_pdfs and pdf_files:
        # gr.File(file_count="multiple") yields a list; tolerate a single item too.
        items = pdf_files if isinstance(pdf_files, (list, tuple)) else [pdf_files]
        for item in items:
            if not item:
                continue
            label = _pdf_label(item)
            try:
                content = _read_pdf_bytes(item)
                if not content:
                    raise ValueError("空のファイルです")
                chunks.append(extract_from_pdf(content))
            except Exception as e:  # a bad PDF must not abort the other sources
                errors.append(f"[PDF抽出エラー: {label} / {e}]")
            else:
                refs.append(label)

    # Whitespace-only text is treated as "no text" rather than sent to the LLM.
    if use_text and text and text.strip():
        chunks.append(text)
        refs.append("text")

    return chunks, refs, errors


def do_generate(content_type: str,
                tone: str,
                use_urls: bool, urls_multi: str,
                use_pdfs: bool, pdf_files,
                use_text: bool, text: str):
    """Generate a draft from several sources at once (URLs + PDFs + free text).

    The successful extracts are concatenated (``_concat_chunks``) and passed to
    ``generate_draft``. The result is saved as a new ``Draft`` with status
    ``pending``. Sources that fail are reported in the status message instead
    of being fed to the LLM.

    Args:
        content_type: value of a ``ContentType`` member.
        tone: value of a ``Tone`` member.
        use_urls: whether to read ``urls_multi``. ``urls_multi`` is a
            newline-separated URL list; blank and ``#`` lines are ignored.
        use_pdfs: whether to read ``pdf_files``. ``pdf_files`` is the
            ``gr.File`` value: ``None``, one item or a list of items.
        use_text: whether to use ``text``. ``text`` is pasted free text.

    Returns:
        ``(draft_id, title, body_md, subject_a, subject_b, message)`` matching
        the Gradio outputs. On failure ``draft_id`` is ``None``, the text
        fields are empty and ``message`` explains why.
    """
    try:
        chunks, source_refs, errors = _collect_sources(
            use_urls, urls_multi, use_pdfs, pdf_files, use_text, text)

        if not chunks:
            if errors:
                # Every enabled source failed: report it instead of asking the
                # LLM to write a release from the error notes alone.
                return None, "", "", "", "", "ソースを取得できませんでした: " + " ".join(errors)
            return None, "", "", "", "", "入力が空です。URL/PDF/テキストのいずれかを指定してください。"

        # Validate enum values before the (slow, possibly paid) LLM call.
        content_type_enum = ContentType(content_type)
        tone_enum = Tone(tone)

        raw = _concat_chunks(chunks, max_chars=20000)
        title, body_md, subj_a, subj_b = generate_draft(raw, content_type, tone)

        # Context-managed session: closed even if the commit raises.
        with _session() as db:
            draft = Draft(
                source_type="mixed",  # several source kinds may be combined
                source_ref=", ".join(source_refs)[:500],
                raw_text=raw,
                content_type=content_type_enum,
                tone=tone_enum,
                title=title,
                body_md=body_md,
                subject_a=subj_a,
                subject_b=subj_b,
                status=DraftStatus.pending,
            )
            db.add(draft)
            db.commit()
            db.refresh(draft)
            new_id = draft.id

        msg = f"ドラフト生成完了: ID={new_id} / sources={len(source_refs)}"
        if errors:
            msg += " / スキップしたソース: " + " ".join(errors)
        return new_id, title, body_md, subj_a, subj_b, msg
    except Exception as e:
        return None, "", "", "", "", f"エラー: {e}"
def do_approve_save(draft_id: int, title: str, body_md: str,
                    subject_a: str, subject_b: str,
                    emails_csv: str, deliver_x: bool, deliver_note: bool):
    """Save the editor contents onto a draft and mark it ``approved``.

    Args:
        draft_id: target draft ID (``None``/0 while nothing is selected).
        title: edited title.
        body_md: edited Markdown body.
        subject_a: edited e-mail subject A.
        subject_b: edited e-mail subject B.
        emails_csv: comma-separated e-mail recipients, stored as typed.
        deliver_x: whether delivery should also post to X.
        deliver_note: whether delivery should also post to note.

    Returns:
        A status message for the UI.
    """
    # gr.Number yields None when empty; int(None) used to raise inside the handler.
    if not draft_id:
        return "ドラフトIDがありません。先にドラフトを生成または読み込んでください。"
    # Context-managed session: closed on every path, including exceptions.
    with _session() as db:
        draft = db.get(Draft, int(draft_id))
        if draft is None:
            return "対象ドラフトが見つかりません。"
        draft.title = title
        draft.body_md = body_md
        draft.subject_a = subject_a
        draft.subject_b = subject_b
        draft.deliver_email_list = emails_csv or ""
        draft.deliver_x = bool(deliver_x)
        draft.deliver_note = bool(deliver_note)
        # NOTE(review): this also re-approves drafts already marked "sent",
        # which re-enables delivery; confirm that is intended.
        draft.status = DraftStatus.approved
        db.commit()
        saved_id = draft.id
    return f"承認・保存しました: ID={saved_id}"
def do_deliver(draft_id: int):
    """Deliver an approved (or scheduled) draft to every channel configured on it.

    Channels are tried in order: e-mail (``deliver_email_list``,
    comma-separated), X (``deliver_x``), then note (``deliver_note``). Each
    attempt is stored as a ``Delivery`` row and committed immediately. After
    all configured channels have been attempted, the draft is marked ``sent``.

    Args:
        draft_id: target draft ID (``None``/0 while nothing is selected).

    Returns:
        A status message for the UI.
    """
    if not draft_id:
        return "ドラフトIDがありません。先にドラフトを生成または読み込んでください。"
    # Context-managed session: closed even if a channel call raises.
    with _session() as db:
        draft = db.get(Draft, int(draft_id))
        if draft is None:
            return "対象ドラフトが見つかりません。"
        if draft.status not in (DraftStatus.approved, DraftStatus.scheduled):
            return f"配信できません。ステータス={draft.status.value}(approved/scheduled 必須)"

        # Read everything up front. The per-channel commits below may expire
        # the instance, depending on SessionLocal's expire_on_commit setting.
        target_id = draft.id
        title = draft.title
        body_md = draft.body_md or ""
        subject_a = draft.subject_a or title or "お知らせ"
        subject_b = draft.subject_b or title or "お知らせ"
        recipients = [e.strip() for e in (draft.deliver_email_list or "").split(",") if e.strip()]
        want_x = bool(draft.deliver_x)
        want_note = bool(draft.deliver_note)

        if not (recipients or want_x or want_note):
            # Previously the draft was flipped to "sent" with nothing delivered.
            return "配信先が設定されていません。メール宛先またはX/noteを指定して承認し直してください。"

        def _record(channel: str, payload: dict, result) -> None:
            # Commit per channel so finished deliveries stay logged even if a
            # later channel raises.
            db.add(Delivery(draft_id=target_id, channel=channel, payload=payload, result=result))
            db.commit()

        # NOTE(review): an exception from a channel aborts the remaining ones and
        # leaves the status unchanged, so a retry re-sends channels that already
        # succeeded. A per-channel retry policy is still to be decided.
        results = []
        if recipients:
            res = send_email(recipients, subject_a, subject_b, body_md)
            _record("email", {"recipients": recipients}, res)
            results.append({"email": res})
        if want_x:
            # Fall back like the other channels instead of posting the literal "None".
            x_text = f"{title or 'お知らせ'}\n\n{body_md[:220]}"
            res = post_to_x(x_text)
            # Log exactly what was posted (the old payload was cut at 280 chars).
            _record("x", {"text": x_text}, res)
            results.append({"x": res})
        if want_note:
            res = post_to_note(title or "お知らせ", body_md)
            _record("note", {"title": title}, res)
            results.append({"note": res})

        draft.status = DraftStatus.sent
        db.commit()
    return f"配信完了: {results}"
def load_history():
    """Return the 30 most recent drafts (newest first) for the history table.

    Returns:
        A ``gr.Dataframe`` update with these columns:
        - id, status, type and tone;
        - title and e-mail recipients;
        - "✓" markers for the X and note flags.
    """
    headers = ["id", "status", "type", "tone", "title", "emails", "X", "note"]
    # Context-managed session: closed even if the query raises (it leaked before).
    with _session() as db:
        drafts = db.query(Draft).order_by(Draft.id.desc()).limit(30).all()
        rows = [
            [
                r.id,
                r.status.value,
                r.content_type.value,
                r.tone.value,
                r.title or "",
                r.deliver_email_list or "",
                "✓" if r.deliver_x else "",
                "✓" if r.deliver_note else "",
            ]
            for r in drafts
        ]
    return gr.Dataframe(headers=headers, value=rows)
def load_draft(draft_id: int):
    """Fetch a stored draft's editable fields for the editor.

    Args:
        draft_id: ID to load (``None``/0 when the field is empty).

    Returns:
        ``(title, body_md, subject_a, subject_b, message)``. The four text
        fields are empty strings when no ID was given or the draft does not exist.
    """
    if not draft_id:
        return "", "", "", "", "ドラフトIDを入力してください。"
    # Context-managed session: closed even if the lookup raises.
    with _session() as db:
        draft = db.get(Draft, int(draft_id))
        if draft is None:
            return "", "", "", "", "見つかりませんでした。"
        return (
            draft.title or "",
            draft.body_md or "",
            draft.subject_a or "",
            draft.subject_b or "",
            f"読み込み: ID={draft.id}",
        )
# ---------------------------
# Gradio UI
# ---------------------------
with gr.Blocks(fill_height=True) as demo:
    gr.Markdown("# PR/IR MiniSaaS(HF-only)")
    with gr.Row():
        # Left column: generation inputs.
        with gr.Column(scale=1):
            content_type = gr.Dropdown(
                label="コンテンツ種別", choices=[e.value for e in ContentType], value="press_release")
            tone = gr.Dropdown(
                label="トーン", choices=[e.value for e in Tone], value="neutral")
            gr.Markdown("### 入力ソース(複数可)")
            use_urls = gr.Checkbox(label="URLを使用", value=True)
            urls_multi = gr.Textbox(
                label="URL(1行1件・複数可)",
                placeholder="https://example.com/one\nhttps://example.com/two",
                lines=4,
            )
            use_pdfs = gr.Checkbox(label="PDFを使用", value=True)
            pdf_files = gr.File(
                label="PDFファイル(複数可)",
                file_types=[".pdf"],
                file_count="multiple",  # Gradio v4: the handler receives a list
            )
            use_text = gr.Checkbox(label="テキストを使用", value=True)
            text = gr.Textbox(label="テキスト(任意)", placeholder="本文を貼り付け...", lines=8)
            gen_btn = gr.Button("ドラフト生成", variant="primary")
            gen_msg = gr.Markdown()
        # Right column: editor plus approve/deliver for the selected draft.
        with gr.Column(scale=2):
            # Target of approve/deliver. Set by generation and by loading a draft.
            draft_id = gr.Number(label="ドラフトID", precision=0, interactive=False)
            title = gr.Textbox(label="タイトル(H1)", lines=1)
            body_md = gr.Textbox(label="本文(Markdown)", lines=18)
            subject_a = gr.Textbox(label="件名A(メールABテスト)", lines=1)
            subject_b = gr.Textbox(label="件名B(メールABテスト)", lines=1)
            emails_csv = gr.Textbox(label="メール宛先(カンマ区切り)", placeholder="a@example.com, b@example.com")
            deliver_x = gr.Checkbox(label="Xにも投稿", value=False)
            deliver_note = gr.Checkbox(label="noteにも投稿(Webhook)", value=False)
            with gr.Row():
                approve_btn = gr.Button("承認して保存")
                deliver_btn = gr.Button("今すぐ配信")
            action_msg = gr.Markdown()

    gr.Markdown("## 履歴")
    hist_btn = gr.Button("履歴を更新")
    hist_df = gr.Dataframe(headers=["id", "status", "type", "tone", "title", "emails", "X", "note"], value=[])
    with gr.Row():
        load_id = gr.Number(label="ドラフトIDを読み込み", precision=0)
        load_btn = gr.Button("読み込み")
    load_msg = gr.Markdown()

    def _load_into_editor(selected_id):
        """Fill the editor via ``load_draft`` and make ``selected_id`` the editing target.

        draft_id is non-interactive, so loading is the only way to retarget it.
        Without this, approve/deliver kept acting on the previously generated
        draft and overwrote it with the loaded content. For an unknown ID the
        target becomes that ID, so approve/deliver answer "not found" instead
        of touching another draft.
        """
        return (*load_draft(selected_id), selected_id)

    # wiring
    gen_btn.click(
        do_generate,
        inputs=[content_type, tone, use_urls, urls_multi, use_pdfs, pdf_files, use_text, text],
        outputs=[draft_id, title, body_md, subject_a, subject_b, gen_msg],
    )
    approve_btn.click(
        do_approve_save,
        inputs=[draft_id, title, body_md, subject_a, subject_b, emails_csv, deliver_x, deliver_note],
        outputs=action_msg,
    )
    deliver_btn.click(
        do_deliver,
        inputs=[draft_id],
        outputs=action_msg,
    )
    hist_btn.click(
        load_history, inputs=None, outputs=hist_df
    )
    load_btn.click(
        _load_into_editor,
        inputs=[load_id],
        outputs=[title, body_md, subject_a, subject_b, load_msg, draft_id],
    )

if __name__ == "__main__":
    demo.launch()