# ESG_IR_RAGbot / app.py — uploaded by Corin1998 ("Upload 7 files", commit 870d2c7, verified)
from __future__ import annotations
import os, json, yaml, subprocess, sys, pathlib, traceback, shutil, re, getpass, stat
from typing import List, Dict, Tuple, Iterable, Optional
from fastapi import FastAPI, Body
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import gradio as gr
# ===== config =====
DEFAULT_CFG = {
"app_name": "IR/ESG RAG Bot (OpenAI, 8 languages)",
"embedding_model": "text-embedding-3-large",
"normalize_embeddings": True,
"chunk": {"target_chars": 1400, "overlap_chars": 180},
"retrieval": {"top_k": 6, "score_threshold": 0.15, "mmr_lambda": 0.3},
"llm": {
"model": "gpt-4o-mini", # 必要に応じて利用可能なモデルに変更
"max_output_tokens": 700,
"temperature": 0.2,
"system_prompt": (
"あなたは上場企業のIR・ESG開示に特化したRAGアシスタントです。"
"回答は常に根拠(文書名・ページ)を箇条書きで示し、文書外の推測や断定は避けます。"
"数値は年度と単位を明記し、最新年度を優先してください。"
),
},
"languages": {
"preferred": ["ja", "en", "zh", "ko", "fr", "de", "es", "it"],
"labels": {
"ja": "日本語", "en": "English", "zh": "中文", "ko": "한국어",
"fr": "Français", "de": "Deutsch", "es": "Español", "it": "Italiano"
},
},
}
CFG_ERR = None
CFG_PATH = "config.yaml"
try:
if os.path.exists(CFG_PATH):
with open(CFG_PATH, encoding="utf-8") as f:
CFG = yaml.safe_load(f) or {}
def _merge(dst, src):
for k, v in src.items():
if k not in dst:
dst[k] = v
_merge(CFG, DEFAULT_CFG)
for sec in ("chunk", "retrieval", "llm", "languages"):
if sec in DEFAULT_CFG:
if sec not in CFG or not isinstance(CFG[sec], dict):
CFG[sec] = DEFAULT_CFG[sec]
else:
_merge(CFG[sec], DEFAULT_CFG[sec])
else:
CFG = DEFAULT_CFG
CFG_ERR = "config.yaml が見つかりません。デフォルト設定で起動しました。"
except Exception as e:
CFG = DEFAULT_CFG
CFG_ERR = "config.yaml 読み込みエラー: " + str(e)
# ===== absolute paths =====
# Anchor every data path to the directory containing this file so the app
# behaves the same regardless of the process's current working directory.
BASE_DIR = pathlib.Path(__file__).resolve().parent
DATA_DIR = BASE_DIR / "data"
INDEX_DIR = DATA_DIR / "index"  # FAISS index + chunk metadata live here
PDF_DIR = DATA_DIR / "pdf"  # uploaded source PDFs live here
INDEX_PATH = INDEX_DIR / "faiss.index"
META_PATH = INDEX_DIR / "meta.jsonl"  # one JSON object per indexed chunk
# Create the directories up front so uploads/rebuilds never hit ENOENT.
PDF_DIR.mkdir(parents=True, exist_ok=True)
INDEX_DIR.mkdir(parents=True, exist_ok=True)
# ===== lazy imports =====
def _lazy_imports():
    """Import heavy/optional dependencies on first use and publish them as module globals.

    faiss, numpy and the OpenAI/guardrails helpers are only needed once a
    question is asked or the index is loaded, so deferring the imports keeps
    startup fast and lets the UI come up before they are exercised.

    Returns the tuple (faiss, np, embed_texts, chat, detect_out_of_scope,
    sanitize, compliance_block, SCOPE_HINT) — callers unpack positionally.
    """
    global faiss, np, embed_texts, chat, detect_out_of_scope, sanitize, compliance_block, SCOPE_HINT
    import faiss, numpy as np
    from openai_client import embed_texts, chat
    from guardrails import detect_out_of_scope, sanitize, compliance_block, SCOPE_HINT
    return faiss, np, embed_texts, chat, detect_out_of_scope, sanitize, compliance_block, SCOPE_HINT
def _index_exists():
    """Report whether both the FAISS index file and its metadata sidecar exist."""
    return INDEX_PATH.exists() and META_PATH.exists()
def _check_api_key(): return bool(os.getenv("OPENAI_API_KEY"))
# ===== retrieval helpers =====
# Process-wide cache of the loaded FAISS index and its chunk metadata.
# Reset to None by rebuild_index() to force a reload after re-ingestion.
_INDEX = None
_METAS = None
def _ensure_index_loaded():
    """Load the FAISS index and chunk metadata into the module-level cache.

    No-op when both are already cached. Raises RuntimeError("index_not_ready")
    when the index files have not been built yet; callers translate that into
    a user-facing hint.
    """
    global _INDEX, _METAS
    if _INDEX is not None and _METAS is not None:
        return
    if not _index_exists():
        raise RuntimeError("index_not_ready")
    faiss, *_ = _lazy_imports()
    _INDEX = faiss.read_index(str(INDEX_PATH))
    # Fix: use a context manager so the metadata file handle is closed
    # deterministically (the original listcomp over open() leaked it to GC).
    with open(META_PATH, encoding="utf-8") as fh:
        _METAS = [json.loads(line) for line in fh]
def _embed_query(q: str):
    """Embed *q* with the configured OpenAI model and L2-normalize the vector.

    Returns a float32 array of shape (1, dim), ready for FAISS search.
    """
    _, np, embed_texts, *_ = _lazy_imports()
    raw = embed_texts([q], CFG["embedding_model"])[0]
    vec = np.array(raw, dtype="float32")
    norm = np.linalg.norm(vec) + 1e-12  # epsilon guards against a zero vector
    return (vec / norm)[None, :]
def _search(q: str):
    """Retrieve up to top_k unique (source, page) chunks relevant to *q*.

    Over-fetches 4x top_k candidates from FAISS, drops those below the
    configured score threshold, and de-duplicates by (source, page) so one
    page cannot crowd out others. Each returned dict is the chunk metadata
    plus a float "score" field. (CFG's mmr_lambda is not applied here.)
    """
    faiss, np, *_ = _lazy_imports()
    _ensure_index_loaded()
    retrieval_cfg = CFG["retrieval"]
    top_k = retrieval_cfg["top_k"]
    threshold = retrieval_cfg["score_threshold"]
    scores, indices = _INDEX.search(_embed_query(q), top_k * 4)
    results = []
    seen_pages = set()
    for score, idx in zip(scores[0], indices[0]):
        if score < threshold:
            continue
        chunk = _METAS[idx]
        page_key = (chunk["source"], chunk["page"])
        if page_key in seen_pages:
            continue
        seen_pages.add(page_key)
        results.append({**chunk, "score": float(score)})
        if len(results) >= top_k:
            break
    return results
def _format_context(chunks: List[Dict]) -> str:
lines = []
for c in chunks:
snippet = c["text"][:180].replace("\n", " ")
lines.append(f"- 出典: {c['source']} p.{c['page']} | 抜粋: {snippet}…")
return "\n".join(lines)
# ===== QA =====
# Per-language instruction appended to the user prompt so the LLM answers in
# the UI-selected language. Keys mirror CFG["languages"]["preferred"].
_LANG_INSTRUCTIONS = {
    "ja": "回答は日本語で出力してください。",
    "en": "Answer in English.",
    "zh": "请用中文回答。",
    "ko": "한국어로 답변하세요.",
    "fr": "Répondez en français.",
    "de": "Bitte auf Deutsch antworten.",
    "es": "Responde en español.",
    "it": "Rispondi in italiano.",
}
def generate_answer(q: str, lang: str = "ja"):
    """Answer *q* via RAG over the local index and return (text, meta).

    *meta* carries a "citations" list of {source, page, score} dicts on
    success and is {} otherwise. Errors are reported in the returned text
    rather than raised, so both the Gradio UI and the JSON API can display
    them directly.
    """
    q = (q or "").strip()
    if not q:
        return "質問を入力してください。", {}
    try:
        # Positional unpack of _lazy_imports(); faiss/np/embed_texts unused here.
        _, _, _, chat, detect_out_of_scope, sanitize, compliance_block, SCOPE_HINT = _lazy_imports()
        # Guardrail: refuse questions outside the IR/ESG scope.
        if detect_out_of_scope(q):
            return f"{SCOPE_HINT}\nIR/ESG関連の事項についてお尋ねください。", {}
        chunks = _search(q)
        context = _format_context(chunks)
        # Fall back to "answer in the user's language" for unknown codes.
        lang_note = _LANG_INSTRUCTIONS.get(lang, "Answer in the user's language.")
        user_prompt = (
            "以下のコンテキストのみを根拠に、簡潔かつ正確に回答してください。\n"
            "必ず箇条書きで根拠(文書名とページ)を列挙してください。\n"
            f"{lang_note}\n\n[コンテキスト]\n{context}\n\n[質問]\n{q}"
        )
        messages = [
            {"role": "system", "content": CFG["llm"]["system_prompt"]},
            {"role": "user", "content": user_prompt},
        ]
        text = chat(messages, model=CFG["llm"]["model"],
                    max_output_tokens=CFG["llm"]["max_output_tokens"],
                    temperature=CFG["llm"]["temperature"])
        # Post-process the model output and append the compliance disclaimer.
        text = sanitize(text) + "\n\n" + compliance_block()
        meta = {"citations": [{"source": c["source"], "page": c["page"], "score": round(c["score"], 3)} for c in chunks]}
        return text, meta
    except RuntimeError as e:
        # Raised by _ensure_index_loaded() before any ingestion has run.
        if str(e) == "index_not_ready":
            return ("⚠️ インデックスがまだありません。\n"
                    "1) PDFをアップロード\n"
                    "2) 『インデックス再構築』ボタン または 『アップロード→再構築』を実行(OpenAI APIキー必須)\n"), {}
        raise
    except Exception as e:
        # Best-effort UI boundary: surface the error plus a traceback tail.
        return "❌ 実行時エラー: " + str(e) + "\n" + traceback.format_exc()[-1200:], {}
# ===== Upload & Rebuild (helpers) =====
SAFE_RE = re.compile(r"[^A-Za-z0-9._-]+")
def _safe_name(name: str) -> str:
base = pathlib.Path(name).name
base = SAFE_RE.sub("_", base)
if not base.lower().endswith(".pdf"):
base += ".pdf"
return base
def save_uploaded_pdfs(file_paths: Optional[Iterable[str]]) -> Tuple[int, List[str], List[str]]:
    """Copy uploaded temp files into data/pdf with sanitized names.

    Returns (saved_count, saved_paths, messages), where *messages* lists
    every skipped file or error as a user-facing Japanese string.
    """
    saved: List[str] = []
    skipped: List[str] = []
    if not file_paths:
        return 0, saved, ["アップロードされたPDFがありません。"]
    try:
        PDF_DIR.mkdir(parents=True, exist_ok=True)
    except Exception as e:
        return 0, [], [f"data/pdf の作成に失敗: {e}"]
    for raw_path in file_paths:
        if not raw_path:
            continue
        src = pathlib.Path(raw_path)
        if not src.exists():
            skipped.append(f"見つかりません: {raw_path}")
            continue
        if src.suffix.lower() != ".pdf":
            skipped.append(f"PDFではありません: {src.name}")
            continue
        dst = PDF_DIR / _safe_name(src.name)
        try:
            # copyfile (not copy2) drops metadata, avoiding permission errors.
            shutil.copyfile(src, dst)
            saved.append(str(dst))
        except PermissionError:
            skipped.append(f"{src.name}: Permission denied({dst})。Dockerfileの所有権設定を確認してください。")
        except Exception as e:
            skipped.append(f"{src.name}: コピー失敗 ({e})")
    return len(saved), saved, skipped
def upload_and_rebuild(file_paths: Optional[List[str]]) -> str:
    """Save the uploaded PDFs, then rebuild the index; return a log string."""
    count, saved, skipped = save_uploaded_pdfs(file_paths)
    report: List[str] = []
    if count > 0:
        report.append(f"📥 {count} 件のPDFを {PDF_DIR} に保存しました。")
        report += [f"- {p}" for p in saved[:10]]  # cap the listing at 10 paths
    if skipped:
        report.append("⚠️ スキップ/エラー:")
        report += [f"- {s}" for s in skipped[:10]]
    report.append("\n🔧 インデックス再構築を開始します…")
    report.append(rebuild_index())
    return "\n".join(report)
def rebuild_index() -> str:
    """Run ingest.py as a subprocess to (re)build the FAISS index.

    Requires OPENAI_API_KEY and at least one PDF under data/pdf. On success
    the in-process index cache is invalidated so the next query reloads it.
    Always returns a user-facing status string and never raises.
    """
    if not _check_api_key():
        return "OPENAI_API_KEY が未設定です。コンソール / Secrets に登録してください。"
    if not list(PDF_DIR.glob("*.pdf")):
        return f"{PDF_DIR} にPDFがありません。PDFをアップロードして再実行してください。"
    try:
        proc = subprocess.run(
            [sys.executable, str(BASE_DIR / "ingest.py")],
            capture_output=True,
            text=True,
            check=True,
        )
    except subprocess.CalledProcessError as e:
        return f"❌ インデックス生成に失敗\nstdout:\n{e.stdout}\n\nstderr:\n{e.stderr}"
    except Exception as e:
        return "❌ 予期せぬエラー: " + str(e) + "\n" + traceback.format_exc()[-1200:]
    # Drop the cached index/metadata so _ensure_index_loaded() reloads them.
    global _INDEX, _METAS
    _INDEX = None
    _METAS = None
    return "✅ インデックス生成完了\n```\n" + (proc.stdout[-1200:] or "") + "\n```"
# ===== File-system diagnose =====
def fs_diagnose() -> str:
    """Collect a storage diagnostic report as a Markdown code block.

    Reports the process user, the CWD, owner/mode for each app directory,
    and whether data/pdf is actually writable (probed with a throwaway file).
    """
    report = [f"User: {getpass.getuser()}", f"CWD : {os.getcwd()}"]
    for path in (BASE_DIR, DATA_DIR, PDF_DIR, INDEX_DIR):
        try:
            info = path.stat()
            perms = stat.filemode(info.st_mode)
            report.append(f"{path} exists={path.exists()} owner={info.st_uid}:{info.st_gid} mode={perms}")
        except Exception as e:
            report.append(f"{path} stat error: {e}")
    try:
        probe = PDF_DIR / "_write_test.tmp"
        with open(probe, "wb") as fh:
            fh.write(b"ok")
        probe.unlink()
        report.append("WRITE TEST: OK (data/pdf に書き込み可能)")
    except Exception as e:
        report.append(f"WRITE TEST: NG -> {e}")
    return "```\n" + "\n".join(report) + "\n```"
# ===== FastAPI =====
app = FastAPI(title=CFG.get("app_name", "RAG Bot"))
# NOTE(review): wildcard CORS origins — acceptable for a public demo Space,
# but tighten allow_origins before exposing any privileged endpoint.
app.add_middleware(
    CORSMiddleware, allow_origins=["*"], allow_credentials=False, allow_methods=["*"], allow_headers=["*"]
)
@app.get("/health")
def health():
    """Liveness probe: always reports ok."""
    return {"status": "ok"}
@app.post("/api/answer")
def api_answer(payload: Dict = Body(...)):
    """JSON API: {"question", "lang"} in → {"text", "citations"} out."""
    question = payload.get("question", "")
    language = payload.get("lang", "ja")
    text, meta = generate_answer(question, language)
    return JSONResponse({"text": text, **meta})
@app.post("/api/rebuild")
def api_rebuild():
    """JSON API: trigger an index rebuild and return its log message."""
    message = rebuild_index()
    return JSONResponse({"message": message})
# ===== Gradio UI mounted at "/" =====
# Dropdown choices; the first preferred language is the default selection.
LANGS = CFG["languages"]["preferred"]
with gr.Blocks(fill_height=True, title=CFG.get("app_name", "RAG Bot")) as demo:
    gr.Markdown("# IR・ESG開示RAG(OpenAI API)— 8言語対応")
    # Show a one-line warning when config.yaml was missing or failed to load.
    if CFG_ERR:
        gr.Markdown(f"**構成警告**: {CFG_ERR}")
    # --- Q&A ---
    with gr.Row():
        q = gr.Textbox(label="質問 / Question", lines=3,
                       placeholder="例: 2024年度のGHG排出量(スコープ1-3)は?")
    with gr.Row():
        lang = gr.Dropdown(choices=LANGS, value=LANGS[0],
                           label="回答言語 / Output language")
    with gr.Row():
        ask = gr.Button("回答する / Answer", variant="primary")
        rebuild_btn = gr.Button("インデックス再構築(ingest.py 実行)")
    ans = gr.Markdown()
    cites = gr.JSON(label="根拠メタデータ / Citations")
    log = gr.Markdown()
    ask.click(fn=generate_answer, inputs=[q, lang], outputs=[ans, cites])
    rebuild_btn.click(fn=rebuild_index, outputs=[log])
    # --- Upload & Rebuild (a gr.State keeps the upload paths stable) ---
    gr.Markdown("## 📄 PDFアップロード")
    with gr.Row():
        uploaded_files = gr.State([])  # holds the uploaded temp file paths
        uploads = gr.Files(
            label="PDFをドラッグ&ドロップ(複数可)",
            file_types=[".pdf"],
            type="filepath",
            file_count="multiple",
        )
    def _capture_files(fs: Optional[List[str]]) -> List[str]:
        # Mirror the Files component's current paths into the State.
        return fs or []
    uploads.change(fn=_capture_files, inputs=[uploads], outputs=[uploaded_files])
    with gr.Row():
        up_btn = gr.Button("アップロード → インデックス再構築", variant="secondary")
        diag_btn = gr.Button("📋 ストレージ診断")
    up_log = gr.Markdown()
    diag_log = gr.Markdown()
    # The buttons read from the State, not the Files component directly.
    up_btn.click(fn=upload_and_rebuild, inputs=[uploaded_files], outputs=[up_log])
    diag_btn.click(fn=fs_diagnose, outputs=[diag_log])
# Mount the Gradio UI onto the FastAPI app at the root path, so "/" serves
# the UI while /health and /api/* remain plain FastAPI routes.
from gradio.routes import mount_gradio_app
mount_gradio_app(app, demo, path="/")