# NOTE(review): the lines "Spaces: / Sleeping / Sleeping" are HuggingFace Spaces
# page-status residue captured with the source, not code; kept here as a comment
# so the file parses as Python.
| from __future__ import annotations | |
| import os, json, yaml, subprocess, sys, pathlib, traceback, shutil, re, getpass, stat | |
| from typing import List, Dict, Tuple, Iterable, Optional | |
| from fastapi import FastAPI, Body | |
| from fastapi.responses import JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import gradio as gr | |
# ===== config =====
# Built-in defaults. config.yaml (when present) overrides these; missing keys
# are back-filled from this dict at startup (see the loader below).
DEFAULT_CFG = {
    "app_name": "IR/ESG RAG Bot (OpenAI, 8 languages)",
    "embedding_model": "text-embedding-3-large",
    "normalize_embeddings": True,
    # Chunking: target chunk size and overlap, in characters.
    "chunk": {"target_chars": 1400, "overlap_chars": 180},
    # Retrieval: keep up to top_k hits scoring above score_threshold.
    # mmr_lambda is not referenced in this file — presumably consumed by the
    # ingest/retrieval side; verify before removing.
    "retrieval": {"top_k": 6, "score_threshold": 0.15, "mmr_lambda": 0.3},
    "llm": {
        "model": "gpt-4o-mini",  # change to any model available to your key, if needed
        "max_output_tokens": 700,
        "temperature": 0.2,
        # System prompt (Japanese): an assistant specialized in listed-company
        # IR/ESG disclosure that always cites document name + page as bullets,
        # avoids speculation beyond the documents, and states fiscal year and
        # units for figures, preferring the latest fiscal year.
        "system_prompt": (
            "あなたは上場企業のIR・ESG開示に特化したRAGアシスタントです。"
            "回答は常に根拠(文書名・ページ)を箇条書きで示し、文書外の推測や断定は避けます。"
            "数値は年度と単位を明記し、最新年度を優先してください。"
        ),
    },
    "languages": {
        # Language codes offered in the UI dropdown, in display order.
        "preferred": ["ja", "en", "zh", "ko", "fr", "de", "es", "it"],
        # Human-readable labels per language code.
        "labels": {
            "ja": "日本語", "en": "English", "zh": "中文", "ko": "한국어",
            "fr": "Français", "de": "Deutsch", "es": "Español", "it": "Italiano"
        },
    },
}
CFG_ERR = None  # human-readable warning shown in the UI when config loading degraded
CFG_PATH = "config.yaml"


def _merge_defaults(dst: dict, src: dict) -> None:
    """Recursively back-fill missing keys in *dst* from *src*.

    Never overwrites a value the user set; descends into nested dicts so that
    deep defaults (e.g. languages.labels entries) are also filled in — the
    previous one-level merge left nested defaults missing.
    """
    for key, value in src.items():
        if key not in dst:
            dst[key] = value
        elif isinstance(dst[key], dict) and isinstance(value, dict):
            _merge_defaults(dst[key], value)


try:
    if os.path.exists(CFG_PATH):
        with open(CFG_PATH, encoding="utf-8") as f:
            CFG = yaml.safe_load(f) or {}
        if not isinstance(CFG, dict):
            # A YAML scalar/list at top level cannot be merged; fail into the
            # except branch so the UI surfaces a config warning.
            raise TypeError("config.yaml must contain a mapping at top level")
        # A section that exists but is not a mapping is replaced wholesale,
        # matching the original section-level fallback behaviour.
        for sec in ("chunk", "retrieval", "llm", "languages"):
            if sec in CFG and not isinstance(CFG[sec], dict):
                CFG[sec] = DEFAULT_CFG[sec]
        _merge_defaults(CFG, DEFAULT_CFG)
    else:
        CFG = DEFAULT_CFG
        CFG_ERR = "config.yaml が見つかりません。デフォルト設定で起動しました。"
except Exception as e:
    CFG = DEFAULT_CFG
    CFG_ERR = "config.yaml 読み込みエラー: " + str(e)
# ===== absolute paths =====
# All persistent data lives under <app dir>/data; directories are created
# eagerly at import time so uploads/rebuilds never race a missing folder.
BASE_DIR = pathlib.Path(__file__).resolve().parent
DATA_DIR = BASE_DIR / "data"
INDEX_DIR = DATA_DIR / "index"           # FAISS index + chunk metadata
PDF_DIR = DATA_DIR / "pdf"               # uploaded source PDFs
INDEX_PATH = INDEX_DIR / "faiss.index"
META_PATH = INDEX_DIR / "meta.jsonl"     # one JSON object per chunk, line-delimited
PDF_DIR.mkdir(parents=True, exist_ok=True)
INDEX_DIR.mkdir(parents=True, exist_ok=True)
# ===== lazy imports =====
def _lazy_imports():
    """Import heavy/optional dependencies on first use and publish them as globals.

    Returns the 8-tuple (faiss, np, embed_texts, chat, detect_out_of_scope,
    sanitize, compliance_block, SCOPE_HINT); call sites unpack it positionally,
    so the order here is part of the contract. Deferring these imports
    presumably keeps cold start fast and lets the UI come up without the
    retrieval stack — confirm against deployment constraints.
    """
    global faiss, np, embed_texts, chat, detect_out_of_scope, sanitize, compliance_block, SCOPE_HINT
    import faiss, numpy as np
    from openai_client import embed_texts, chat
    from guardrails import detect_out_of_scope, sanitize, compliance_block, SCOPE_HINT
    return faiss, np, embed_texts, chat, detect_out_of_scope, sanitize, compliance_block, SCOPE_HINT
def _index_exists():
    """Return True when both the FAISS index file and its metadata file exist."""
    return INDEX_PATH.exists() and META_PATH.exists()
| def _check_api_key(): return bool(os.getenv("OPENAI_API_KEY")) | |
# ===== retrieval helpers =====
# Lazily-loaded retrieval state; reset to None after a rebuild to force reload.
_INDEX = None
_METAS = None


def _ensure_index_loaded():
    """Load the FAISS index and chunk metadata into the module-level caches.

    No-op when both caches are already populated.

    Raises:
        RuntimeError("index_not_ready"): when the index files do not exist yet
            (callers translate this into a user-facing hint).
    """
    global _INDEX, _METAS
    if _INDEX is not None and _METAS is not None:
        return
    if not _index_exists():
        raise RuntimeError("index_not_ready")
    faiss, *_ = _lazy_imports()
    _INDEX = faiss.read_index(str(INDEX_PATH))
    # Context manager closes the metadata file deterministically — the original
    # left the handle open for the garbage collector to reclaim.
    with open(META_PATH, encoding="utf-8") as f:
        _METAS = [json.loads(line) for line in f]
def _embed_query(q: str):
    """Embed *q* with the configured model and L2-normalize; returns shape (1, dim)."""
    lazy = _lazy_imports()
    np, embed_texts = lazy[1], lazy[2]
    vec = np.array(embed_texts([q], CFG["embedding_model"])[0], dtype="float32")
    # Epsilon guards against division by zero on a degenerate all-zero vector.
    vec /= np.linalg.norm(vec) + 1e-12
    return vec[None, :]
def _search(q: str):
    """Return up to top_k chunks relevant to *q*, deduplicated by (source, page).

    Over-fetches 4x top_k so that threshold filtering and deduplication still
    leave enough candidates, then keeps the best-scoring top_k.
    """
    _lazy_imports()  # ensure faiss/np are importable before touching the index
    _ensure_index_loaded()
    top_k = CFG["retrieval"]["top_k"]
    threshold = CFG["retrieval"]["score_threshold"]
    sims, idxs = _INDEX.search(_embed_query(q), top_k * 4)
    sims, idxs = sims[0], idxs[0]
    picked, seen = [], set()
    for score, idx in zip(sims, idxs):
        # FAISS pads the result lists with index -1 when fewer than k vectors
        # exist; without this guard, _METAS[-1] silently returned the LAST
        # metadata entry as a bogus hit.
        if idx < 0:
            continue
        if score < threshold:
            continue
        chunk = _METAS[idx]
        key = (chunk["source"], chunk["page"])
        if key in seen:
            continue
        seen.add(key)
        picked.append({**chunk, "score": float(score)})
        if len(picked) >= top_k:
            break
    return picked
| def _format_context(chunks: List[Dict]) -> str: | |
| lines = [] | |
| for c in chunks: | |
| snippet = c["text"][:180].replace("\n", " ") | |
| lines.append(f"- 出典: {c['source']} p.{c['page']} | 抜粋: {snippet}…") | |
| return "\n".join(lines) | |
# ===== QA =====
# Per-language output instruction appended to the user prompt; generate_answer
# falls back to "Answer in the user's language." for unknown codes.
_LANG_INSTRUCTIONS = {
    "ja": "回答は日本語で出力してください。",
    "en": "Answer in English.",
    "zh": "请用中文回答。",
    "ko": "한국어로 답변하세요.",
    "fr": "Répondez en français.",
    "de": "Bitte auf Deutsch antworten.",
    "es": "Responde en español.",
    "it": "Rispondi in italiano.",
}
def generate_answer(q: str, lang: str = "ja"):
    """Answer *q* from the indexed documents, in the requested output language.

    Returns a (markdown_text, metadata) pair; metadata carries the citation
    list when retrieval succeeded, otherwise an empty dict.
    """
    question = (q or "").strip()
    if not question:
        return "質問を入力してください。", {}
    try:
        # _lazy_imports returns (faiss, np, embed_texts, chat, detect_out_of_scope,
        # sanitize, compliance_block, SCOPE_HINT); we need the last five here.
        mods = _lazy_imports()
        chat, detect_out_of_scope, sanitize, compliance_block, SCOPE_HINT = mods[3:]
        if detect_out_of_scope(question):
            return f"{SCOPE_HINT}\nIR/ESG関連の事項についてお尋ねください。", {}
        chunks = _search(question)
        context = _format_context(chunks)
        lang_note = _LANG_INSTRUCTIONS.get(lang, "Answer in the user's language.")
        user_prompt = (
            "以下のコンテキストのみを根拠に、簡潔かつ正確に回答してください。\n"
            "必ず箇条書きで根拠(文書名とページ)を列挙してください。\n"
            f"{lang_note}\n\n[コンテキスト]\n{context}\n\n[質問]\n{question}"
        )
        llm_cfg = CFG["llm"]
        messages = [
            {"role": "system", "content": llm_cfg["system_prompt"]},
            {"role": "user", "content": user_prompt},
        ]
        answer = chat(
            messages,
            model=llm_cfg["model"],
            max_output_tokens=llm_cfg["max_output_tokens"],
            temperature=llm_cfg["temperature"],
        )
        answer = sanitize(answer) + "\n\n" + compliance_block()
        citations = [
            {"source": c["source"], "page": c["page"], "score": round(c["score"], 3)}
            for c in chunks
        ]
        return answer, {"citations": citations}
    except RuntimeError as e:
        # Missing index gets a friendly how-to; anything else propagates.
        if str(e) == "index_not_ready":
            return ("⚠️ インデックスがまだありません。\n"
                    "1) PDFをアップロード\n"
                    "2) 『インデックス再構築』ボタン または 『アップロード→再構築』を実行(OpenAI APIキー必須)\n"), {}
        raise
    except Exception as e:
        # Last-resort boundary: surface a truncated traceback to the UI.
        return "❌ 実行時エラー: " + str(e) + "\n" + traceback.format_exc()[-1200:], {}
| # ===== Upload & Rebuild (helpers) ===== | |
| SAFE_RE = re.compile(r"[^A-Za-z0-9._-]+") | |
| def _safe_name(name: str) -> str: | |
| base = pathlib.Path(name).name | |
| base = SAFE_RE.sub("_", base) | |
| if not base.lower().endswith(".pdf"): | |
| base += ".pdf" | |
| return base | |
def save_uploaded_pdfs(file_paths: Optional[Iterable[str]]) -> Tuple[int, List[str], List[str]]:
    """Copy uploaded temp files into PDF_DIR under sanitized names.

    Returns (saved_count, saved_paths, skipped_messages); non-PDFs, missing
    paths, and copy failures are reported in skipped_messages rather than raised.
    """
    saved: List[str] = []
    skipped: List[str] = []
    if not file_paths:
        return 0, saved, ["アップロードされたPDFがありません。"]
    try:
        PDF_DIR.mkdir(parents=True, exist_ok=True)
    except Exception as e:
        return 0, [], [f"data/pdf の作成に失敗: {e}"]
    for raw in file_paths:
        if not raw:
            continue
        src = pathlib.Path(raw)
        if not src.exists():
            skipped.append(f"見つかりません: {raw}")
            continue
        if src.suffix.lower() != ".pdf":
            skipped.append(f"PDFではありません: {src.name}")
            continue
        dst = PDF_DIR / _safe_name(src.name)
        try:
            # copyfile (not copy2) deliberately drops file metadata, which
            # avoids permission errors on restricted container filesystems.
            shutil.copyfile(src, dst)
            saved.append(str(dst))
        except PermissionError:
            skipped.append(f"{src.name}: Permission denied({dst})。Dockerfileの所有権設定を確認してください。")
        except Exception as e:
            skipped.append(f"{src.name}: コピー失敗 ({e})")
    return len(saved), saved, skipped
def upload_and_rebuild(file_paths: Optional[List[str]]) -> str:
    """Persist uploaded PDFs, then rebuild the index; returns a status report."""
    count, saved, skipped = save_uploaded_pdfs(file_paths)
    report: List[str] = []
    if count > 0:
        report.append(f"📥 {count} 件のPDFを {PDF_DIR} に保存しました。")
        # Cap the listing at 10 entries to keep the UI message short.
        report += [f"- {p}" for p in saved[:10]]
    if skipped:
        report.append("⚠️ スキップ/エラー:")
        report += [f"- {s}" for s in skipped[:10]]
    report.append("\n🔧 インデックス再構築を開始します…")
    report.append(rebuild_index())
    return "\n".join(report)
def rebuild_index() -> str:
    """Run ingest.py as a subprocess to (re)build the FAISS index.

    Returns a user-facing status message; never raises. On success the cached
    index/metadata are invalidated so the next query reloads the fresh files.
    """
    if not _check_api_key():
        return "OPENAI_API_KEY が未設定です。コンソール / Secrets に登録してください。"
    if not any(PDF_DIR.glob("*.pdf")):
        return f"{PDF_DIR} にPDFがありません。PDFをアップロードして再実行してください。"
    try:
        result = subprocess.run(
            [sys.executable, str(BASE_DIR / "ingest.py")],
            capture_output=True,
            text=True,
            check=True,
        )
    except subprocess.CalledProcessError as e:
        return f"❌ インデックス生成に失敗\nstdout:\n{e.stdout}\n\nstderr:\n{e.stderr}"
    except Exception as e:
        return "❌ 予期せぬエラー: " + str(e) + "\n" + traceback.format_exc()[-1200:]
    # Invalidate the in-process caches; _ensure_index_loaded reloads on demand.
    global _INDEX, _METAS
    _INDEX = None
    _METAS = None
    return "✅ インデックス生成完了\n```\n" + (result.stdout[-1200:] or "") + "\n```"
# ===== File-system diagnose =====
def fs_diagnose() -> str:
    """Report process identity, data-directory permissions, and PDF_DIR writeability.

    Returns a fenced-code markdown block for display in the UI.
    """
    report = [f"User: {getpass.getuser()}", f"CWD : {os.getcwd()}"]
    for path in (BASE_DIR, DATA_DIR, PDF_DIR, INDEX_DIR):
        try:
            info = path.stat()
            perms = stat.filemode(info.st_mode)
            report.append(f"{path} exists={path.exists()} owner={info.st_uid}:{info.st_gid} mode={perms}")
        except Exception as e:
            report.append(f"{path} stat error: {e}")
    # Probe actual writeability — stat alone can't prove the mount isn't read-only.
    probe = PDF_DIR / "_write_test.tmp"
    try:
        with open(probe, "wb") as f:
            f.write(b"ok")
        probe.unlink()
        report.append("WRITE TEST: OK (data/pdf に書き込み可能)")
    except Exception as e:
        report.append(f"WRITE TEST: NG -> {e}")
    return "```\n" + "\n".join(report) + "\n```"
# ===== FastAPI =====
app = FastAPI(title=CFG.get("app_name", "RAG Bot"))
app.add_middleware(
    CORSMiddleware, allow_origins=["*"], allow_credentials=False, allow_methods=["*"], allow_headers=["*"]
)


# NOTE(review): these handlers had no route decorators, so FastAPI never
# registered them and the HTTP API was dead code. Paths below follow the
# conventional /health + /api/* scheme — confirm against API consumers.
@app.get("/health")
def health():
    """Liveness probe."""
    return {"status": "ok"}


@app.post("/api/answer")
def api_answer(payload: Dict = Body(...)):
    """JSON API: {"question": str, "lang": str} -> {"text": str, "citations": [...]}."""
    text, meta = generate_answer(payload.get("question", ""), payload.get("lang", "ja"))
    return JSONResponse({"text": text, **meta})


@app.post("/api/rebuild")
def api_rebuild():
    """JSON API: trigger an index rebuild; returns ingest output or an error message."""
    return JSONResponse({"message": rebuild_index()})
# ===== Gradio UI mounted at "/" =====
LANGS = CFG["languages"]["preferred"]
with gr.Blocks(fill_height=True, title=CFG.get("app_name", "RAG Bot")) as demo:
    gr.Markdown("# IR・ESG開示RAG(OpenAI API)— 8言語対応")
    if CFG_ERR:
        # Surface config-loading problems (missing/broken config.yaml) in the UI.
        gr.Markdown(f"**構成警告**: {CFG_ERR}")
    # --- Q&A ---
    with gr.Row():
        q = gr.Textbox(label="質問 / Question", lines=3,
                       placeholder="例: 2024年度のGHG排出量(スコープ1-3)は?")
    with gr.Row():
        lang = gr.Dropdown(choices=LANGS, value=LANGS[0],
                           label="回答言語 / Output language")
    with gr.Row():
        ask = gr.Button("回答する / Answer", variant="primary")
        rebuild_btn = gr.Button("インデックス再構築(ingest.py 実行)")
    ans = gr.Markdown()
    cites = gr.JSON(label="根拠メタデータ / Citations")
    log = gr.Markdown()
    ask.click(fn=generate_answer, inputs=[q, lang], outputs=[ans, cites])
    rebuild_btn.click(fn=rebuild_index, outputs=[log])
    # --- Upload & rebuild (stabilized via gr.State) ---
    gr.Markdown("## 📄 PDFアップロード")
    with gr.Row():
        uploaded_files = gr.State([])  # holds the upload widget's temp file paths
        uploads = gr.Files(
            label="PDFをドラッグ&ドロップ(複数可)",
            file_types=[".pdf"],
            type="filepath",
            file_count="multiple",
        )

        def _capture_files(fs: Optional[List[str]]) -> List[str]:
            # Copy the widget's file list into State (None becomes an empty list).
            return fs or []

        uploads.change(fn=_capture_files, inputs=[uploads], outputs=[uploaded_files])
    with gr.Row():
        up_btn = gr.Button("アップロード → インデックス再構築", variant="secondary")
        diag_btn = gr.Button("📋 ストレージ診断")
    up_log = gr.Markdown()
    diag_log = gr.Markdown()
    # The buttons read from State, not the Files widget directly.
    up_btn.click(fn=upload_and_rebuild, inputs=[uploaded_files], outputs=[up_log])
    diag_btn.click(fn=fs_diagnose, outputs=[diag_log])
from gradio.routes import mount_gradio_app
# Serve the Gradio UI at the root path of the FastAPI app defined above.
mount_gradio_app(app, demo, path="/")