# deepdive-IR / app.py  (Hugging Face Space)
# Last commit: 100da6f "Update app.py" by Ritabanm (verified)
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional, Dict
import os, shutil, logging, traceback
from pathlib import Path
# ===== Persistent storage on HF =====
# Both locations can be overridden from the environment:
#   DATA_DIR  -> base data mount (default "/data")
#   INDEX_DIR -> root for per-CIK indexes (default "<DATA_DIR>/index")
DATA_DIR = os.environ.get("DATA_DIR", "/data")
INDEX_ROOT = os.environ.get("INDEX_DIR", os.path.join(DATA_DIR, "index"))
# Created at import time, so an unwritable location fails fast at startup.
os.makedirs(INDEX_ROOT, exist_ok=True)
from agent.graph import AgentGraph
from agent.tools import FetchTools
from ingest.sec import fetch_recent_filings_by_cik
app = FastAPI(title="DeepDive IR Agent")
# Log through uvicorn's "uvicorn.error" logger (presumably so messages go
# through the handlers uvicorn already configured — confirm with deployment).
log = logging.getLogger("uvicorn.error")
# ===== CORS: localhost + vercel previews + your prod app =====
# Origins are configurable from the environment so a different prod URL does
# not require a code edit; the defaults reproduce the original hard-coded setup.
#   CORS_ALLOW_ORIGINS : comma-separated list of exact origins.
#   CORS_ORIGIN_REGEX  : regex matched against the request Origin; set it to an
#                        empty string to disable regex matching entirely.
# NOTE(review): the default regex admits every *.vercel.app origin (not only
# this project's previews) while allow_credentials=True. Consider narrowing it,
# e.g. r"https://deepdive-ir-agent-[a-z0-9-]+\.vercel\.app".
_DEFAULT_CORS_ORIGINS = "http://localhost:3000,https://deepdive-ir-agent.vercel.app"
_DEFAULT_CORS_ORIGIN_REGEX = r"https://.*\.vercel\.app$"

app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        origin.strip()
        for origin in os.getenv("CORS_ALLOW_ORIGINS", _DEFAULT_CORS_ORIGINS).split(",")
        if origin.strip()
    ],
    # `or None` turns an explicitly empty env value into "no regex".
    allow_origin_regex=os.getenv("CORS_ORIGIN_REGEX", _DEFAULT_CORS_ORIGIN_REGEX) or None,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ===== Models =====
# Request bodies. Field names, order and annotations are reflected in the
# generated OpenAPI schema, so treat them as part of the public API.
class IngestRequest(BaseModel):
    # Body of POST /ingest.
    cik: str  # raw CIK from the client; the route normalizes it with norm_cik()
    ir_url: Optional[str] = None  # optional investor-relations page, fetched and indexed alongside filings


class AskRequest(BaseModel):
    # Body of POST /ask.
    question: str  # passed as-is to AgentGraph.answer()
    cik: Optional[str] = None # optional; if omitted we use the last ingested CIK
# ===== State =====
# In-memory only: nothing in this file reloads these after a process restart,
# so a fresh /ingest is needed before /ask or /brief can answer again.
tools = FetchTools()  # shared fetcher used by /ingest for filing and IR pages
last_cik: Optional[str] = None  # most recently ingested CIK; default for /ask and /brief
graphs: Dict[str, AgentGraph] = dict()  # normalized 10-digit CIK -> built AgentGraph
# ===== Helpers =====
def norm_cik(raw: str) -> str:
    """Normalize an SEC CIK to its canonical 10-digit, zero-padded form.

    Surrounding whitespace is ignored. Only ASCII digits are accepted:
    ``str.isdigit`` on its own also accepts Unicode digits such as "²" or
    Arabic-Indic "١٢٣". Those would otherwise be zero-padded into an invalid
    CIK, which callers also use as a directory name under INDEX_ROOT.

    Args:
        raw: CIK as supplied by the client.

    Returns:
        The CIK left-padded with zeros to 10 characters.

    Raises:
        HTTPException(400): if the value is empty, contains anything but ASCII
            digits, or is longer than 10 digits.
    """
    s = raw.strip()
    if not (s.isascii() and s.isdigit()):
        raise HTTPException(400, "CIK must be digits only.")
    if len(s) > 10:
        raise HTTPException(400, "CIK too long; use 10 digits.")
    return s.zfill(10)
def idx_dir_for(cik: str) -> str:
    """Return ``<INDEX_ROOT>/<cik>``, creating the directory (and parents) if missing."""
    target = os.path.join(INDEX_ROOT, cik)
    os.makedirs(target, exist_ok=True)
    return target
# ===== Routes =====
@app.get("/")
def root():
    # Landing endpoint: fixed payload identifying the service.
    return dict(ok=True, msg="DeepDive IR Agent API")
@app.get("/healthz")
def healthz():
    # Liveness probe: always succeeds while the process is serving requests.
    return dict(ok=True)
async def _fetch_docs(cik: str, ir_url: Optional[str]) -> list:
    """Fetch text for the CIK's recent SEC filings plus the optional IR page.

    Returns a list of ``{"title", "url", "text"}`` dicts. Fetching is
    best-effort per document: a fetch that raises is logged and skipped, and
    an empty text is skipped silently, so one bad URL cannot abort the ingest.
    """
    filings = await fetch_recent_filings_by_cik(cik)
    docs = []
    for _form, url, title in filings:
        try:
            text = await tools.get_text_from_url(url)
        except Exception as e:  # deliberate best-effort: skip this filing
            log.warning("Fetch failed for %s: %s", url, e)
            continue
        if text:
            docs.append({"title": title, "url": url, "text": text})
    if ir_url:
        try:
            ir_text = await tools.get_text_from_url(ir_url)
        except Exception as e:  # deliberate best-effort: IR page is optional
            log.warning("IR fetch failed for %s: %s", ir_url, e)
        else:
            if ir_text:
                docs.append({"title": "IR site", "url": ir_url, "text": ir_text})
    return docs


def _rebuild_index(idx_dir: str, docs: list) -> "AgentGraph":
    """Wipe *idx_dir* and build a fresh AgentGraph over *docs* inside it."""
    # Resolve first: DATA_DIR / INDEX_DIR may be relative, and after the chdir
    # below a relative index_dir would resolve to <idx_dir>/<idx_dir>.
    idx_dir = os.path.abspath(idx_dir)
    shutil.rmtree(idx_dir, ignore_errors=True)
    Path(idx_dir).mkdir(parents=True, exist_ok=True)
    # Some libs may write relative paths like "index/vecs.npy"; build from the
    # CIK directory so those land under <INDEX_ROOT>/<CIK>/... .
    # os.chdir is process-wide. Keep this section free of awaits so no other
    # coroutine runs with the temporary cwd. NOTE(review): sync routes running
    # in FastAPI's threadpool (/ask, /brief) can still observe it.
    prev = os.getcwd()
    os.chdir(idx_dir)
    try:
        graph = AgentGraph(index_dir=idx_dir)  # absolute per-CIK dir
        graph.build_index(docs)
    finally:
        os.chdir(prev)
    return graph


@app.post("/ingest")
async def ingest(req: IngestRequest):
    """
    Build a fresh index for this CIK under <INDEX_ROOT>/<CIK>.

    INDEX_ROOT defaults to /data/index and can be overridden via INDEX_DIR.
    Recent SEC filings (and `ir_url`, if given) are fetched, the CIK's index
    directory is wiped and rebuilt, and the new graph becomes the default for
    /ask and /brief.

    Returns {"ok": true, "cik": <10-digit CIK>, "num_docs": <docs indexed>}.
    Responds 400 for an invalid CIK or when no document text could be
    fetched, and 502 for any other failure.
    """
    global last_cik
    try:
        cik = norm_cik(req.cik)
        docs = await _fetch_docs(cik, req.ir_url)
        if not docs:
            raise HTTPException(400, "No documents fetched.")
        # NOTE(review): an existing graphs[cik] stays registered while its
        # on-disk index is wiped; if the rebuild fails it keeps serving.
        # Confirm AgentGraph does not read index files lazily after build.
        g = _rebuild_index(idx_dir_for(cik), docs)
        graphs[cik] = g
        last_cik = cik
        return {"ok": True, "cik": cik, "num_docs": len(docs)}
    except HTTPException:
        raise
    except Exception as e:
        log.exception("Ingest failed: %s", e)
        raise HTTPException(
            status_code=502, detail=f"Ingest failed: {type(e).__name__}: {e}"
        ) from e
@app.post("/ask")
def ask(req: AskRequest):
    """
    Answer `question` against a previously ingested CIK's index.

    Uses `cik` when provided (normalized to 10 digits), otherwise the CIK of
    the most recent successful /ingest. Responds 400 for an invalid CIK or
    when no index has been built for it, and 500 if answering fails.
    """
    try:
        cik = norm_cik(req.cik) if req.cik else last_cik
        # Single lookup instead of a membership test followed by indexing.
        graph = graphs.get(cik) if cik else None
        if graph is None:
            raise HTTPException(400, "No index available. Call /ingest with a CIK first.")
        return graph.answer(req.question)
    except HTTPException:
        raise
    except Exception as e:
        log.exception("Ask failed: %s", e)
        # Include the exception type: str(e) alone is empty for exceptions
        # raised without arguments. Matches the /ingest error format.
        raise HTTPException(500, detail=f"{type(e).__name__}: {e}") from e
@app.get("/brief")
def brief(cik: Optional[str] = None):
    """
    Return the AgentGraph brief for a previously ingested CIK.

    Uses the `cik` query parameter when provided (normalized to 10 digits),
    otherwise the CIK of the most recent successful /ingest. Responds 400 for
    an invalid CIK or when no index has been built for it, and 500 if
    building the brief fails.
    """
    try:
        c = norm_cik(cik) if cik else last_cik
        # Single lookup instead of a membership test followed by indexing.
        graph = graphs.get(c) if c else None
        if graph is None:
            raise HTTPException(400, "No index available. Call /ingest with a CIK first.")
        return graph.brief()
    except HTTPException:
        raise
    except Exception as e:
        log.exception("Brief failed: %s", e)
        # Include the exception type: str(e) alone is empty for exceptions
        # raised without arguments. Matches the /ingest error format.
        raise HTTPException(500, detail=f"{type(e).__name__}: {e}") from e