Spaces:
Sleeping
Sleeping
File size: 2,995 Bytes
398a289 d7bb68c 398a289 d7bb68c 398a289 d7bb68c 398a289 d7bb68c 398a289 d7bb68c 398a289 d7bb68c 398a289 d7bb68c 398a289 0f6ee3a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 | import os
import sys
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from pydantic import BaseModel
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from src.models.fusion_inference import FusionClaimVerifier, _resolve_fusion_model_path
# ── Global verifier (pre-warmed at startup) ─────────────────────────────────
# Assigned once by ``lifespan``; remains None if model loading failed.
_verifier: FusionClaimVerifier | None = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model once at container startup so requests never cold-start.

    Configuration comes from environment variables: FUSION_MODEL,
    OPENSEARCH_INDEX_NAME (falling back to OP_KB_NAME), LLM_FINETUNE,
    RETRIEVER_MODEL, DEVICE and FUSION_LLM_EVIDENCE_TOP_K.

    Any exception raised while building the verifier is logged together with
    its traceback and then swallowed. ``_verifier`` stays None, so ``/health``
    reports ``model_loaded: false`` and ``/verify`` answers with an error
    payload, instead of the whole application aborting startup.

    Args:
        app: The FastAPI application (unused; required by the lifespan protocol).
    """
    global _verifier
    logger.info("[startup] Pre-warming FusionClaimVerifier …")
    try:
        fusion_path = _resolve_fusion_model_path(
            os.getenv("FUSION_MODEL", "models/fusion_model.pt")
        )
        _verifier = FusionClaimVerifier(
            fusion_model_path=fusion_path,
            # OPENSEARCH_INDEX_NAME wins; an unset or empty value falls through
            # to OP_KB_NAME (default "news_kb").
            opensearch_index=os.getenv("OPENSEARCH_INDEX_NAME")
            or os.getenv("OP_KB_NAME", "news_kb"),
            llm_model_path=os.getenv("LLM_FINETUNE"),
            retriever_model_path=os.getenv(
                "RETRIEVER_MODEL", "AITeamVN/Vietnamese_Embedding"
            ),
            device=os.getenv("DEVICE", "cpu"),
            # A non-integer value raises ValueError, which is handled below
            # like any other load failure.
            llm_evidence_top_k=int(os.getenv("FUSION_LLM_EVIDENCE_TOP_K", "3")),
            debug=True,
        )
        logger.info("[startup] FusionClaimVerifier ready ✓")
    except Exception:
        # Deliberate best-effort: keep serving (/health, error payloads from
        # /verify) rather than crash the container. logger.exception records
        # the active traceback, so no manual formatting is needed.
        logger.exception("[startup] Failed to load verifier")
    yield
    # shutdown: nothing to clean up
    logger.info("[shutdown] API server stopping.")
app = FastAPI(lifespan=lifespan, title="Fake Crypto Claim Detector API")

# Fully open CORS policy: any origin, method and header, with credentials.
# NOTE(review): combining a wildcard origin with allow_credentials=True means
# Starlette's CORSMiddleware reflects the caller's Origin for credentialed
# requests (verify against the pinned Starlette version). If the API never
# relies on cookies/auth headers, consider allow_credentials=False or an
# explicit origin list.
_cors_policy = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_policy)
# JSON request body for POST /verify.
# `claim` is the free-text statement to check; verify_claim passes it unchanged
# to the verifier's predict(). Only the `str` type is enforced; there is no
# length or emptiness validation.
class ClaimRequest(BaseModel):
    claim: str
# Liveness probe: always reports status "ok", plus whether the verifier was
# successfully built at startup (False when loading failed or has not run).
@app.get("/health")
def health():
    model_ready = _verifier is not None
    return dict(status="ok", model_loaded=model_ready)
# Declared with plain `def` (not `async def`) so FastAPI runs the blocking
# predict() call in its threadpool rather than on the event loop.
@app.post("/verify")
def verify_claim(request: ClaimRequest):
    """Verify a single claim and return the model's verdict.

    Success: ``{"verdict": ..., "status": "success"}``. Failures are reported in
    the response body with ``"status": "error"`` and an ``"error"`` message; the
    HTTP status code stays 200.
    """
    if _verifier is None:
        # Startup could not build the model. The message reads: "Verifier has
        # not been initialized (see the startup log for the reason)."
        return {
            "verdict": "Lỗi xử lý",
            "status": "error",
            "error": "Verifier chưa được khởi tạo (xem log startup để biết lý do).",
        }
    try:
        prediction = _verifier.predict(request.claim)
        return {"verdict": prediction.verdict, "status": "success"}
    except Exception as e:
        import traceback

        error_traceback = traceback.format_exc()
        # Log through the same loguru logger as startup (traceback included)
        # instead of a raw print to stdout.
        logger.exception("API Error")
        # NOTE(review): the full server-side traceback is returned to the
        # client, exposing internal paths/details on a public endpoint. It is
        # kept here only for response-shape compatibility.
        return {
            "verdict": "Lỗi xử lý",
            "status": "error",
            "error": str(e),
            "traceback": error_traceback,
        }
|