"""
KG Embedding Server — FastAPI application intended for HuggingFace Spaces.

Session-based design: a client syncs its nodes once, after which the server
keeps a FAISS index for that session in RAM.
"""

import contextlib
from io import StringIO
from typing import Dict, Any, Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

app = FastAPI(title="KG Embedding Server")

# ══════════════════════════════════════════════════════
# GLOBALS
# ══════════════════════════════════════════════════════

_model = None  # SentenceTransformer instance; stays None if loading failed
_use_st = False  # flipped to True once the embedding model has loaded
_faiss = None  # the imported faiss module, or None when unavailable
_np = None  # the imported numpy module, or None when unavailable
_sessions: Dict[str, Dict[str, Any]] = {}  # session_id -> {"nodes": ..., "index": ...}


def _load_sentence_transformer():
    """Try to load the embedding model; on success set `_model` and `_use_st`."""
    global _model, _use_st
    try:
        from sentence_transformers import SentenceTransformer

        # Swallow whatever the library writes to stdout/stderr while loading.
        out_sink, err_sink = StringIO(), StringIO()
        with contextlib.redirect_stdout(out_sink):
            with contextlib.redirect_stderr(err_sink):
                _model = SentenceTransformer("all-MiniLM-L6-v2")
        _use_st = True
        print("[Server] sentence-transformers loaded ✓")
    except Exception as exc:
        # Best-effort: the server still starts without the model.
        print(f"[Server] ST unavailable: {exc}")


def _load_faiss():
    """Try to import faiss and publish the module through `_faiss`."""
    global _faiss
    try:
        import faiss

        _faiss = faiss
        print("[Server] faiss loaded ✓")
    except Exception as exc:
        print(f"[Server] faiss unavailable: {exc}")


def _load_numpy():
    """Try to import numpy and publish the module through `_np`."""
    global _np
    try:
        import numpy

        _np = numpy
        print("[Server] numpy loaded ✓")
    except Exception as exc:
        print(f"[Server] numpy unavailable: {exc}")


@app.on_event("startup")
def load_model():
    """
    Startup hook: load the optional heavy dependencies one by one.

    Each dependency is attempted independently; a failure is printed and the
    matching module-level global keeps its "unavailable" value (None / False)
    instead of aborting startup.
    """
    _load_sentence_transformer()
    _load_faiss()
    _load_numpy()


# ══════════════════════════════════════════════════════
# REQUEST / RESPONSE MODELS
# ══════════════════════════════════════════════════════

# One node uploaded by the client.
class NodePayload(BaseModel):
    id: int
    title: str
    content: str


# Body of a sync request: a session id plus the nodes to index for it.
class SyncRequest(BaseModel):
    session_id: str
    nodes: list[NodePayload]


# Result of a sync: a status string and the number of nodes indexed.
class SyncResponse(BaseModel):
    status: str
    count: int


# Body of a search request; `top_k` defaults to 8.
class SearchRequest(BaseModel):
    session_id: str
    query: str
    top_k: int = 8


# A single search hit: the matching node's id and its score.
class SearchResult(BaseModel):
    node_id: int
    score: float


# Search reply; `session_missing` defaults to False.
class SearchResponse(BaseModel):
    results: list[SearchResult]
    session_missing: bool = False


# Health report: overall status plus which backends are loaded.
class HealthResponse(BaseModel):
    status: str
    model_loaded: bool
    faiss_loaded: bool


# ══════════════════════════════════════════════════════
# ENDPOINTS
# ══════════════════════════════════════════════════════

@app.get("/health", response_model=HealthResponse)
def health():
    """Report liveness and whether the embedding model and faiss are loaded."""
    return {
        "status": "ok",
        "model_loaded": _use_st,
        "faiss_loaded": _faiss is not None,
    }


@app.post("/sync", response_model=SyncResponse)
def sync(req: SyncRequest):
    """
    Client uploads all nodes once.
    Server generates embeddings + builds FAISS index in RAM.

    The session entry for `req.session_id` is replaced on every call. An
    empty node list stores a session with no index and returns count=0.

    Raises:
        HTTPException(503): the embedding model, faiss or numpy is not loaded.
    """
    if not _use_st or _model is None:
        raise HTTPException(503, "sentence-transformers not loaded")
    if _faiss is None or _np is None:
        raise HTTPException(503, "faiss / numpy not loaded")

    texts = [f"{n.title} {n.content}" for n in req.nodes]
    if not texts:
        _sessions[req.session_id] = {"nodes": [], "index": None}
        return SyncResponse(status="ok", count=0)

    # Embed all nodes in a single batch, normalized so that cosine similarity
    # equals the plain dot product afterwards.
    # NOTE: no redirect_stdout/redirect_stderr here. Those swap the
    # process-global sys.stdout/sys.stderr, and this handler runs on a worker
    # thread, so overlapping requests could leave the streams pointing at a
    # discarded buffer. show_progress_bar=False keeps encode() quiet instead.
    vecs = _model.encode(texts, show_progress_bar=False, normalize_embeddings=True)
    vecs = _np.array(vecs, dtype="float32")

    dim = vecs.shape[1]
    # IndexFlatIP = exact inner product (= cosine after normalization).
    index = _faiss.IndexFlatIP(dim)
    index.add(vecs)

    _sessions[req.session_id] = {
        # Kept in insertion order: FAISS row i corresponds to nodes[i].
        "nodes": [{"id": n.id, "title": n.title} for n in req.nodes],
        "index": index,
    }

    print(f"[Server] /sync → session={req.session_id[:8]}… | {len(texts)} nodes indexed ✓")
    return SyncResponse(status="ok", count=len(texts))


@app.post("/search", response_model=SearchResponse)
def search(req: SearchRequest):
    """
    Embed query only → exact inner-product FAISS search against cached index.
    Zero candidate transfer per search.

    Returns `session_missing=True` with no results when the session is
    unknown or holds no index, so the client can tell it has to sync again.
    A non-positive `top_k` yields an empty result list.

    Raises:
        HTTPException(503): the embedding model is not loaded.
    """
    if not _use_st or _model is None:
        raise HTTPException(503, "sentence-transformers not loaded")

    session = _sessions.get(req.session_id)
    if not session or session["index"] is None:
        return SearchResponse(results=[], session_missing=True)

    k = min(req.top_k, len(session["nodes"]))
    if k <= 0:
        # FAISS rejects k <= 0; answer with "no hits" instead of a 500.
        return SearchResponse(results=[])

    # No stdout/stderr redirection here: it mutates process-global state and
    # is unsafe on worker threads. show_progress_bar=False keeps encode() quiet.
    qvec = _model.encode(
        [req.query], show_progress_bar=False, normalize_embeddings=True
    )
    qvec = _np.array(qvec, dtype="float32")

    scores, indices = session["index"].search(qvec, k)

    nodes = session["nodes"]
    results = [
        SearchResult(node_id=nodes[int(idx)]["id"], score=float(score))
        for score, idx in zip(scores[0], indices[0])
        if idx >= 0  # FAISS pads missing hits with -1
    ]
    return SearchResponse(results=results)


@app.get("/ping")
def ping():
    """Keep-alive — set up a cron job that calls this every 4 minutes."""
    return {"pong": True}