# (HuggingFace Spaces page chrome — status: Running — not part of the module source.)
"""
KG Embedding Server — FastAPI on HuggingFace Spaces
Session-based: client syncs once → server holds FAISS index in RAM
"""
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Dict, Any, Optional
import contextlib
from io import StringIO

app = FastAPI(title="KG Embedding Server")

# ──────────────────────────────────────────────────────
# GLOBALS (populated lazily by load_model(); None/False until then)
# ──────────────────────────────────────────────────────
_model = None    # SentenceTransformer instance, once loaded
_use_st = False  # True only after the sentence-transformers model loaded
_faiss = None    # the faiss module, once imported
_np = None       # the numpy module, once imported
# In-RAM per-session state: session_id -> {"nodes": [{"id", "title"}, ...],
# "index": faiss index or None}. Never evicted — sessions live until restart.
_sessions: Dict[str, Dict[str, Any]] = {}
def load_model():
    """
    Best-effort loader for the optional runtime dependencies.

    Populates the module globals `_model`/`_use_st` (sentence-transformers),
    `_faiss`, and `_np`. Every failure is caught and logged — deliberately —
    so the server still starts and /health can report the degraded state.
    """
    global _model, _use_st, _faiss, _np

    # sentence-transformers: the model download/construction is chatty, so
    # stdout/stderr are silenced for the constructor. The constructor stays
    # inside the try so a failed model load is also tolerated.
    try:
        from sentence_transformers import SentenceTransformer
        with contextlib.redirect_stdout(StringIO()), contextlib.redirect_stderr(StringIO()):
            _model = SentenceTransformer("all-MiniLM-L6-v2")
        _use_st = True
        print("[Server] sentence-transformers loaded β")
    except Exception as e:
        print(f"[Server] ST unavailable: {e}")

    # faiss: needed for the in-RAM vector index.
    try:
        import faiss as faiss_module
        _faiss = faiss_module
        print("[Server] faiss loaded β")
    except Exception as e:
        print(f"[Server] faiss unavailable: {e}")

    # numpy: needed to convert embeddings to float32 arrays for faiss.
    try:
        import numpy
        _np = numpy
        print("[Server] numpy loaded β")
    except Exception as e:
        print(f"[Server] numpy unavailable: {e}")
# ──────────────────────────────────────────────────────
# REQUEST / RESPONSE MODELS
# ──────────────────────────────────────────────────────
class NodePayload(BaseModel):
    """One knowledge-graph node uploaded by the client for indexing."""
    id: int       # client-side node identifier; echoed back in search results
    title: str    # indexed (concatenated with content) and kept in the session
    content: str  # indexed only; not retained after embedding
class SyncRequest(BaseModel):
    """Full node upload: replaces any existing index for this session."""
    session_id: str  # opaque client-chosen key into the server-side session map
    nodes: list[NodePayload]
class SyncResponse(BaseModel):
    """Acknowledgement for /sync."""
    status: str  # always "ok" on success
    count: int   # number of nodes embedded and indexed
class SearchRequest(BaseModel):
    """Semantic search against a previously synced session index."""
    session_id: str
    query: str
    top_k: int = 8  # requested result count; clamped to the corpus size
class SearchResult(BaseModel):
    """Single hit: node id plus its cosine-similarity score."""
    node_id: int
    score: float
class SearchResponse(BaseModel):
    """Search results; session_missing signals the client must re-sync."""
    results: list[SearchResult]
    session_missing: bool = False  # True when the session was never synced (or synced empty)
class HealthResponse(BaseModel):
    """Shape of the /health probe payload."""
    status: str
    model_loaded: bool
    faiss_loaded: bool
# ──────────────────────────────────────────────────────
# ENDPOINTS
# ──────────────────────────────────────────────────────
def health():
    """Liveness/readiness probe: reports which optional deps actually loaded."""
    # NOTE(review): no @app.get(...) decorator is visible in this chunk —
    # presumably stripped during extraction; confirm the route is registered.
    payload = dict(
        status="ok",
        model_loaded=_use_st,
        faiss_loaded=_faiss is not None,
    )
    return payload
def sync(req: SyncRequest):
    """
    Receive the client's full node set once and build an in-RAM index.

    Embeds every node as "title content" in a single batch, L2-normalizes the
    vectors (so inner product == cosine similarity), and caches a FAISS
    IndexFlatIP under req.session_id. Raises 503 when the model / faiss /
    numpy failed to load at startup.
    """
    if not _use_st or _model is None:
        raise HTTPException(503, "sentence-transformers not loaded")
    if _faiss is None or _np is None:
        raise HTTPException(503, "faiss / numpy not loaded")

    texts = [f"{n.title} {n.content}" for n in req.nodes]

    # Empty upload: remember the session but hold no index.
    if not texts:
        _sessions[req.session_id] = {"nodes": [], "index": None}
        return SyncResponse(status="ok", count=0)

    # Embed all nodes in one batch, normalized so dot product = cosine.
    # The model's progress output is silenced.
    with contextlib.redirect_stdout(StringIO()), contextlib.redirect_stderr(StringIO()):
        embeddings = _model.encode(texts, show_progress_bar=False, normalize_embeddings=True)
    embeddings = _np.array(embeddings, dtype="float32")

    # IndexFlatIP = exact inner-product search (= cosine after normalization).
    index = _faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings)

    _sessions[req.session_id] = {
        "nodes": [{"id": n.id, "title": n.title} for n in req.nodes],
        "index": index,
    }
    print(f"[Server] /sync β session={req.session_id[:8]}β¦ | {len(texts)} nodes indexed β")
    return SyncResponse(status="ok", count=len(texts))
def search(req: SearchRequest):
    """
    Embed the query only, then run an exact FAISS search against the
    session's cached index — zero candidate transfer per search.

    Returns session_missing=True when the session was never synced (or was
    synced empty) so the client knows to re-sync. Raises 503 when required
    runtime deps are missing, matching the guards in sync().
    """
    if not _use_st or _model is None:
        raise HTTPException(503, "sentence-transformers not loaded")
    # Fix: sync() guards faiss/numpy but search() did not, yet _np is used
    # below — without this, a failed numpy load produced a raw 500.
    if _faiss is None or _np is None:
        raise HTTPException(503, "faiss / numpy not loaded")

    session = _sessions.get(req.session_id)
    if not session or session["index"] is None:
        return SearchResponse(results=[], session_missing=True)

    # Embed only the query, normalized like the corpus (cosine via dot).
    with contextlib.redirect_stdout(StringIO()), contextlib.redirect_stderr(StringIO()):
        qvec = _model.encode([req.query], normalize_embeddings=True)
    qvec = _np.array(qvec, dtype="float32")

    # Clamp k to the corpus size; a non-positive top_k yields no results
    # instead of being passed to FAISS.
    k = min(req.top_k, len(session["nodes"]))
    if k <= 0:
        return SearchResponse(results=[])

    scores, indices = session["index"].search(qvec, k)
    results = []
    for score, idx in zip(scores[0], indices[0]):
        if idx >= 0:  # FAISS pads with -1 when fewer than k vectors exist
            results.append(SearchResult(
                node_id=session["nodes"][idx]["id"],
                score=float(score),
            ))
    return SearchResponse(results=results)
def ping():
    """Keep-alive endpoint — an external cron job hits this every 4 minutes."""
    # NOTE(review): no @app.get(...) decorator is visible in this chunk —
    # presumably stripped during extraction; confirm the route is registered.
    return {"pong": True}