# axonembedding / app.py
# 1MR's picture
# Update app.py
# 3b054bc verified
"""
KG Embedding Server β€” FastAPI on HuggingFace Spaces
Session-based: client syncs once β†’ server holds FAISS index in RAM
"""
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Dict, Any, Optional
import contextlib
from io import StringIO
# The FastAPI application; the startup hook and every endpoint below attach to it.
app = FastAPI(title="KG Embedding Server")
# ══════════════════════════════════════════════════════
# GLOBALS
# ══════════════════════════════════════════════════════
# The four handles below start out empty and are filled in by load_model()
# at startup; any of them may stay unset if its package fails to import.
_model = None  # SentenceTransformer instance used to embed text
_use_st = False  # True once the sentence-transformers model has loaded
_faiss = None  # the imported faiss module
_np = None  # the imported numpy module
# In-RAM session store, keyed by session_id. Each value is a dict with
#   "nodes": list of {"id", "title"} dicts in the order they were indexed
#   "index": the FAISS index for those nodes, or None for an empty session
_sessions: Dict[str, Dict[str, Any]] = {}
@app.on_event("startup")
def load_model():
    """Load the optional backends once, when the application starts.

    sentence-transformers, faiss and numpy are each imported inside their
    own ``try`` so that a missing package is reported on stdout instead of
    aborting startup.  Results are published through the module-level
    globals ``_model``, ``_use_st``, ``_faiss`` and ``_np``.
    """
    global _model, _use_st, _faiss, _np

    # --- sentence-transformers: the embedding model -----------------------
    try:
        from sentence_transformers import SentenceTransformer

        # Discard anything the model constructor writes while loading.
        out_sink, err_sink = StringIO(), StringIO()
        with contextlib.redirect_stdout(out_sink), contextlib.redirect_stderr(err_sink):
            _model = SentenceTransformer("all-MiniLM-L6-v2")
    except Exception as exc:
        print(f"[Server] ST unavailable: {exc}")
    else:
        _use_st = True
        print("[Server] sentence-transformers loaded βœ“")

    # --- faiss: vector index ----------------------------------------------
    try:
        # ``_faiss`` is declared global above, so the import binds it directly.
        import faiss as _faiss
    except Exception as exc:
        print(f"[Server] faiss unavailable: {exc}")
    else:
        print("[Server] faiss loaded βœ“")

    # --- numpy: array conversion ------------------------------------------
    try:
        import numpy as _np
    except Exception as exc:
        print(f"[Server] numpy unavailable: {exc}")
    else:
        print("[Server] numpy loaded βœ“")
# ══════════════════════════════════════════════════════
# REQUEST / RESPONSE MODELS
# ══════════════════════════════════════════════════════
class NodePayload(BaseModel):
    """One knowledge-graph node as uploaded by the client in a ``/sync`` call."""
    # Client-side identifier; returned as ``node_id`` in search results.
    id: int
    # ``/sync`` joins title and content with a single space and embeds the result.
    title: str
    content: str
class SyncRequest(BaseModel):
    """Body of ``POST /sync``: the complete node set for one session."""
    # Key under which the server stores this session's nodes and index.
    session_id: str
    # Every node to index; an empty list registers the session without an index.
    nodes: list[NodePayload]
class SyncResponse(BaseModel):
    """Reply of ``POST /sync``."""
    # Set to "ok" by the endpoint on success.
    status: str
    # Number of nodes that were embedded and indexed (0 for an empty upload).
    count: int
class SearchRequest(BaseModel):
    """Body of ``POST /search``: a text query against a previously synced session."""
    # Must match the ``session_id`` used in the earlier ``/sync`` call.
    session_id: str
    # Free-text query; it is embedded with the same model as the nodes.
    query: str
    # Maximum number of hits wanted; the endpoint caps it at the session's node count.
    top_k: int = 8
class SearchResult(BaseModel):
    """A single search hit."""
    # The ``id`` of the matching node, as supplied by the client at sync time.
    node_id: int
    # Inner-product score from the index; vectors are normalised, so this is
    # the cosine similarity between query and node.
    score: float
class SearchResponse(BaseModel):
    """Reply of ``POST /search``."""
    # Hits in the order returned by the index search.
    results: list[SearchResult]
    # True when the server holds no index for the requested session
    # (unknown session id, or a session that was synced with no nodes).
    session_missing: bool = False
class HealthResponse(BaseModel):
    """Reply of ``GET /health``."""
    # The endpoint always reports "ok" here.
    status: str
    # Whether the sentence-transformers model loaded at startup.
    model_loaded: bool
    # Whether the faiss module was imported at startup.
    faiss_loaded: bool
# ══════════════════════════════════════════════════════
# ENDPOINTS
# ══════════════════════════════════════════════════════
@app.get("/health", response_model=HealthResponse)
def health():
    """Report liveness and which optional backends finished loading."""
    faiss_ready = _faiss is not None
    return dict(status="ok", model_loaded=_use_st, faiss_loaded=faiss_ready)
@app.post("/sync", response_model=SyncResponse)
def sync(req: SyncRequest):
    """Replace a session's node set and rebuild its in-memory index.

    The client uploads all of its nodes once.  Each node's title and content
    are embedded together, and the unit-normalised vectors are stored in a
    FAISS ``IndexFlatIP`` held in RAM under ``req.session_id``.  An empty
    ``nodes`` list registers the session with no index.

    Args:
        req: Session identifier plus the full list of nodes to index.

    Returns:
        ``SyncResponse`` with ``status="ok"`` and the number of nodes indexed.

    Raises:
        HTTPException: 503 when sentence-transformers, faiss or numpy did
            not load at startup.
    """
    if not _use_st or _model is None:
        raise HTTPException(503, "sentence-transformers not loaded")
    if _faiss is None or _np is None:
        raise HTTPException(503, "faiss / numpy not loaded")

    if not req.nodes:
        _sessions[req.session_id] = {"nodes": [], "index": None}
        return SyncResponse(status="ok", count=0)

    texts = [f"{n.title} {n.content}" for n in req.nodes]

    # Embed all nodes in one batch, normalised so that cosine similarity
    # reduces to a plain dot product later on.
    #
    # Deliberately NOT wrapped in contextlib.redirect_stdout/redirect_stderr:
    # those swap the process-wide sys.stdout/sys.stderr, and this endpoint
    # runs in a worker thread, so overlapping requests could restore the
    # streams in the wrong order and silence all later output.
    # show_progress_bar=False already keeps the encoder quiet.
    vecs = _model.encode(texts, show_progress_bar=False, normalize_embeddings=True)
    vecs = _np.array(vecs, dtype="float32")

    # IndexFlatIP = exact inner product (= cosine after normalisation).
    index = _faiss.IndexFlatIP(vecs.shape[1])
    index.add(vecs)

    # Only id and title are kept per node; row i of the index is nodes[i].
    _sessions[req.session_id] = {
        "nodes": [{"id": n.id, "title": n.title} for n in req.nodes],
        "index": index,
    }
    print(f"[Server] /sync β†’ session={req.session_id[:8]}… | {len(texts)} nodes indexed βœ“")
    return SyncResponse(status="ok", count=len(texts))
@app.post("/search", response_model=SearchResponse)
def search(req: SearchRequest):
    """Embed the query and search the session's cached index.

    Only the query text is sent per request; the candidate vectors already
    live in the index built by ``/sync``.  That index is an ``IndexFlatIP``,
    so the search is an exact inner-product search, not an approximate one.

    Args:
        req: Session identifier, query text and the maximum number of hits.

    Returns:
        ``SearchResponse`` with the hits.  ``session_missing`` is True when
        the server holds no index for the session (unknown id, or a session
        synced with no nodes).  A non-positive ``top_k`` yields an empty
        result list.

    Raises:
        HTTPException: 503 when sentence-transformers did not load at startup.
    """
    if not _use_st or _model is None:
        raise HTTPException(503, "sentence-transformers not loaded")

    session = _sessions.get(req.session_id)
    if not session or session["index"] is None:
        return SearchResponse(results=[], session_missing=True)

    nodes = session["nodes"]
    k = min(req.top_k, len(nodes))
    if k <= 0:
        # FAISS requires k > 0; asking for no hits simply returns no hits
        # instead of surfacing an internal error.
        return SearchResponse(results=[])

    # No redirect_stdout/redirect_stderr here: they replace the process-wide
    # streams and are not safe across the worker threads that serve this
    # endpoint.  show_progress_bar=False keeps the encoder quiet instead.
    qvec = _model.encode([req.query], show_progress_bar=False, normalize_embeddings=True)
    qvec = _np.array(qvec, dtype="float32")

    scores, indices = session["index"].search(qvec, k)

    # One query row was submitted, so row 0 holds its hits.  Negative labels
    # mark unfilled slots and are skipped.
    results = [
        SearchResult(node_id=nodes[int(idx)]["id"], score=float(score))
        for score, idx in zip(scores[0], indices[0])
        if idx >= 0
    ]
    return SearchResponse(results=results)
@app.get("/ping")
def ping():
    """Keep-alive — set up a cron job that calls this every 4 minutes."""
    reply = {"pong": True}
    return reply