Corin1998's picture
Update irpr/main.py
2dd13f2 verified
# irpr/main.py
from __future__ import annotations
from fastapi import FastAPI, UploadFile, File, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from fastapi.templating import Jinja2Templates
import os, traceback
from irpr.models import IngestRequest, GenerateRequest
from irpr.config import settings
# Application instance; title and version are passed straight to FastAPI.
app = FastAPI(title="IR/PR Co-Pilot Pro", version="0.4.5 (OpenAI)")

# CORS: the API stays open to every origin, method and header.
# Credentialed cross-origin requests are NOT enabled: the CORS rules do not
# allow wildcard origins to be combined with credentials, and none of the
# routes visible in this module read cookies or authorization headers.
# If cookie-based auth is ever added, replace "*" with an explicit origin
# list before switching allow_credentials back on.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ===== Directory selection (same strategy as deps) =====
def _ensure_dir_writable(path: str) -> bool:
try:
os.makedirs(path, exist_ok=True)
try:
os.chmod(path, 0o777)
except Exception:
pass
testfile = os.path.join(path, ".write_test")
with open(testfile, "wb") as f:
f.write(b"ok")
os.remove(testfile)
return True
except Exception:
return False
def _ensure_dir_tree(base: str, sub: str = "simple_index") -> bool:
    """Check that *base* and its child directory *sub* both exist and are writable.

    The child is only probed when the parent passed, so a failing *base*
    short-circuits to ``False`` without touching ``base/sub``.
    """
    return _ensure_dir_writable(base) and _ensure_dir_writable(os.path.join(base, sub))
def _pick_writable_dir() -> str:
    """Return the first candidate data root whose directory tree is writable.

    Candidates, in order: ``settings.DATA_DIR`` (only when truthy),
    ``/tmp/irpr``, ``/mnt/data`` and ``<cwd>/data``. Each is accepted when
    both it and its ``simple_index`` child pass the writability probe.
    When no candidate passes, ``/tmp/irpr`` is probed once more and returned
    regardless of the outcome.
    """
    configured = [settings.DATA_DIR] if settings.DATA_DIR else []
    defaults = ["/tmp/irpr", "/mnt/data", os.path.join(os.getcwd(), "data")]

    # Lazy generator: probing stops at the first directory that works.
    chosen = next(
        (root for root in configured + defaults if _ensure_dir_tree(root, "simple_index")),
        None,
    )
    if chosen is not None:
        return chosen

    last_resort = "/tmp/irpr"
    _ensure_dir_tree(last_resort, "simple_index")
    return last_resort
# Resolve the data root once at import time; the upload and index
# directories are fixed children of it.
BASE_DIR = _pick_writable_dir()
UPLOAD_DIR = os.path.join(BASE_DIR, "uploads")
INDEX_DIR = os.path.join(BASE_DIR, "simple_index")
# Create both sub-directories up front (the boolean results are ignored here;
# /api/health re-runs the same probes and reports them).
_ensure_dir_writable(UPLOAD_DIR)
_ensure_dir_writable(INDEX_DIR)
# Static file serving
# "static" and "templates" are relative paths, i.e. resolved against the
# process's current working directory; they are created before being mounted.
os.makedirs("static", exist_ok=True)
os.makedirs("templates", exist_ok=True)
app.mount("/static", StaticFiles(directory="static"), name="static")
# NOTE(review): /files serves everything under BASE_DIR, which includes
# both uploads/ and simple_index/ — confirm that exposing the index is intended.
app.mount("/files", StaticFiles(directory=BASE_DIR), name="files")
templates = Jinja2Templates(directory="templates")
# ===== UI / Health =====
@app.get("/", response_class=HTMLResponse)
def ui(request: Request):
    """Render ``index.html`` with the request and the resolved ``BASE_DIR``."""
    context = {"request": request, "base_dir": BASE_DIR}
    return templates.TemplateResponse("index.html", context)
@app.get("/api/health")
def health():
    """Report service identity, resolved directories and live writability.

    The two ``*_writable`` entries are produced by re-running the directory
    probe on every call, so this endpoint touches the filesystem.
    """
    status = {"ok": True, "service": "IR/PR Co-Pilot Pro"}
    # Directories actually in use by this process.
    status.update(base_dir=BASE_DIR, upload_dir=UPLOAD_DIR, index_dir=INDEX_DIR)
    # Raw values from settings, for comparison with the resolved paths.
    status["data_dir_env"] = settings.DATA_DIR
    status["index_dir_env"] = settings.INDEX_DIR
    # Fresh probes, upload directory first.
    status["upload_writable"] = _ensure_dir_writable(UPLOAD_DIR)
    status["index_writable"] = _ensure_dir_writable(INDEX_DIR)
    return status
# ===== ingest =====
@app.post("/ingest/edinet")
def ingest_edinet(req: IngestRequest):
    """Run EDINET ingestion for ``req.edinet_code`` / ``req.date``.

    Delegates to ``rag.ingest.ingest_edinet_for_company`` and returns its
    result under ``ingested_chunks``. Any exception raised by that call is
    converted into an ``ok: False`` payload carrying ``repr`` of the error
    and the last 2000 characters of the formatted traceback.
    """
    from rag.ingest import ingest_edinet_for_company

    try:
        chunk_count = ingest_edinet_for_company(req.edinet_code, req.date)
    except Exception as exc:
        formatted = traceback.format_exception(type(exc), exc, exc.__traceback__)
        return {"ok": False, "error": repr(exc), "trace": "".join(formatted)[-2000:]}
    else:
        return {"ok": True, "ingested_chunks": chunk_count}
def _safe_filename(name: str) -> str:
name = os.path.basename(name or "").strip()
if not name:
name = "upload.pdf"
return name
@app.post("/ingest/upload")
async def ingest_upload(files: list[UploadFile] = File(...)):
    """Save uploaded PDFs under ``UPLOAD_DIR`` and ingest each one.

    Files whose sanitised name does not end in ``.pdf`` are skipped silently.
    Processing stops at the first failure (empty file, save error, ingest
    error); files handled before that point stay saved and ingested.

    Returns:
        A dict with ``ok``, ``ingested_chunks`` (running total of the values
        returned by ``ingest_pdf_bytes``), ``saved`` (``/files/uploads/...``
        URLs of the files fully processed so far) and ``base_dir``. Failure
        payloads add ``error`` and, when an exception was involved, ``trace``
        (last 2000 characters of the formatted traceback).
    """
    from rag.ingest import ingest_pdf_bytes

    total = 0
    saved = []

    def _failure(message, exc=None):
        # Single place that builds every failure payload, so all of them
        # report the progress made so far (total / saved) consistently.
        payload = {"ok": False, "error": message}
        if exc is not None:
            payload["trace"] = "".join(
                traceback.format_exception(type(exc), exc, exc.__traceback__)
            )[-2000:]
        payload.update(ingested_chunks=total, saved=saved, base_dir=BASE_DIR)
        return payload

    try:
        if not _ensure_dir_writable(UPLOAD_DIR):
            raise PermissionError(f"UPLOAD_DIR not writable: {UPLOAD_DIR}")
        if not files:
            return _failure("no files")

        for upload in files:
            fname = _safe_filename(upload.filename)
            if not fname.lower().endswith(".pdf"):
                continue

            blob = await upload.read()
            if not blob:
                # Earlier files in this request may already be ingested, so
                # report the running totals rather than zero / empty list.
                return _failure(f"empty file: {fname}")

            os.makedirs(UPLOAD_DIR, exist_ok=True)
            save_path = os.path.join(UPLOAD_DIR, fname)
            try:
                with open(save_path, "wb") as out:
                    out.write(blob)
            except Exception as exc:
                return _failure(f"{repr(exc)} at path={save_path}", exc)

            source_url = f"/files/uploads/{fname}"
            try:
                added = ingest_pdf_bytes(title=fname, source_url=source_url, pdf_bytes=blob)
            except Exception as exc:
                return _failure(f"{repr(exc)} while ingesting file={fname}", exc)

            total += added
            saved.append(source_url)

        return {"ok": True, "ingested_chunks": total, "saved": saved, "base_dir": BASE_DIR}
    except Exception as exc:
        # Boundary handler: anything unexpected still yields a JSON payload.
        return _failure(repr(exc), exc)
# ===== generate =====
@app.post("/generate/all")
def generate_all(req: GenerateRequest):
    """Generate the summary deck and the Q&A CSV for ``req.query``.

    Steps: build a summary via ``make_summary``, cut it into named sections
    with ``_extract_section``, request 30 Q&A items via ``make_qa``, then
    write ``ir_summary.pptx`` and ``qa_30.csv`` into ``BASE_DIR``.

    Returns:
        On success: ``ok``, ``pptx`` and ``qa_csv`` (``/files/...`` URLs),
        ``links`` (de-duplicated, order-preserving union of the links from
        both generators) and ``base_dir``. On failure: ``ok: False`` with
        ``error`` and, except for the missing-PPTX-dependency case, ``trace``.
    """

    def _failure(exc):
        tb = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))[-2000:]
        return {"ok": False, "error": repr(exc), "trace": tb, "base_dir": BASE_DIR}

    # Give a clear hint when the PPTX dependencies are not installed.
    try:
        from export.ppt import build_deck, save_pptx
    except ModuleNotFoundError as exc:
        # Only blame the PPTX stack when one of its packages is the module
        # that is actually missing; any other missing module gets the generic
        # error payload so the real cause is visible.
        missing_root = (exc.name or "").split(".")[0]
        if missing_root not in ("pptx", "lxml", "PIL"):
            return _failure(exc)
        return {
            "ok": False,
            "error": "python-pptx が未インストールのため PPTX を生成できません。requirements.txt に `python-pptx`, `lxml`, `Pillow` を追加して再起動してください。",
            "base_dir": BASE_DIR,
        }

    try:
        from generators.summary import make_summary
        from generators.qa import make_qa
        from export.qa_csv import save_qa_csv

        summary_text, links = make_summary(req.query)
        sections = {
            "highlights": _extract_section(summary_text, "業績ハイライト"),
            "outlook": _extract_section(summary_text, "見通し"),
            "segments": _extract_section(summary_text, "セグメント"),
            "finance": _extract_section(summary_text, "財務"),
            "shareholder": _extract_section(summary_text, "株主還元"),
            "esg": _extract_section(summary_text, "ESG"),
            "risks": _extract_section(summary_text, "リスク"),
        }
        qa_list, links2 = make_qa(req.query, 30)

        ppt_path = os.path.join(BASE_DIR, "ir_summary.pptx")
        csv_path = os.path.join(BASE_DIR, "qa_30.csv")
        prs = build_deck(sections, links)
        save_pptx(prs, ppt_path)
        save_qa_csv(qa_list, links2, csv_path)

        def to_url(p):
            # Map a path under BASE_DIR to its URL under the /files mount.
            rel = os.path.relpath(p, BASE_DIR).replace("\\", "/")
            return f"/files/{rel}"

        return {
            "ok": True,
            "pptx": to_url(ppt_path),
            "qa_csv": to_url(csv_path),
            # dict.fromkeys de-duplicates while keeping first-seen order.
            "links": list(dict.fromkeys((links or []) + (links2 or []))),
            "base_dir": BASE_DIR,
        }
    except Exception as exc:
        return _failure(exc)
def _extract_section(text: str, head: str):
import re
pat = rf"{head}[::]\s*(.*?)(?:\n[^\n]*[::]|\Z)"
m = re.search(pat, text, re.S)
return (m.group(1).strip() if m else "").strip()
from starlette.requests import Request as SRequest
@app.exception_handler(Exception)
async def _all_exception_handler(request: SRequest, exc: Exception):
    """Catch-all handler: turn any unhandled exception into a JSON payload.

    The response deliberately uses HTTP 200 and signals failure through
    ``ok: False``, alongside ``repr`` of the exception, the last 2000
    characters of its traceback, and ``BASE_DIR``.
    """
    frames = traceback.format_exception(type(exc), exc, exc.__traceback__)
    body = {
        "ok": False,
        "error": repr(exc),
        "trace": "".join(frames)[-2000:],
        "base_dir": BASE_DIR,
    }
    return JSONResponse(status_code=200, content=body)