Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,23 +1,22 @@
|
|
| 1 |
"""
|
| 2 |
KG Embedding Server — FastAPI on HuggingFace Spaces
|
| 3 |
-
|
| 4 |
"""
|
| 5 |
|
| 6 |
-
from fastapi import FastAPI
|
| 7 |
from pydantic import BaseModel
|
| 8 |
-
from typing import Optional
|
| 9 |
-
import math
|
| 10 |
-
import re
|
| 11 |
-
import contextlib
|
| 12 |
from io import StringIO
|
| 13 |
|
| 14 |
app = FastAPI(title="KG Embedding Server")
|
| 15 |
|
| 16 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 17 |
-
# GLOBALS
|
| 18 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 19 |
# Embedding model handle; None until a model has been assigned to it.
_model = None
# Flag reported by /health as "model_loaded"; starts out False.
_use_st = False
|
|
|
|
| 21 |
|
| 22 |
@app.on_event("startup")
|
| 23 |
def load_model():
|
|
@@ -30,24 +29,35 @@ def load_model():
|
|
| 30 |
print("[Server] sentence-transformers loaded β")
|
| 31 |
except Exception as e:
|
| 32 |
print(f"[Server] ST unavailable: {e}")
|
| 33 |
-
_use_st = False
|
| 34 |
|
| 35 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 36 |
-
#
|
| 37 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 38 |
-
class
|
| 39 |
-
|
|
|
|
|
|
|
| 40 |
|
| 41 |
-
class
|
| 42 |
-
|
| 43 |
-
|
| 44 |
|
| 45 |
-
class
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
query: str
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
-
class
|
| 50 |
-
|
|
|
|
| 51 |
|
| 52 |
class HealthResponse(BaseModel):
|
| 53 |
status: str
|
|
@@ -61,37 +71,66 @@ class HealthResponse(BaseModel):
|
|
| 61 |
def health():
|
| 62 |
return {"status": "ok", "model_loaded": _use_st}
|
| 63 |
|
| 64 |
-
@app.post("/
|
| 65 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
if not _use_st or _model is None:
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
with contextlib.redirect_stdout(StringIO()), contextlib.redirect_stderr(StringIO()):
|
| 71 |
-
vecs = _model.encode(
|
| 72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
|
| 74 |
-
@app.post("/similarity", response_model=SimilarityResponse)
|
| 75 |
-
def similarity(req: SimilarityRequest):
|
| 76 |
-
"""Query + candidates → cosine scores (the fastest option if you want to rank results)"""
|
| 77 |
-
if not _use_st or _model is None:
|
| 78 |
-
raise HTTPException(503, "Model not loaded")
|
| 79 |
-
texts = [req.query] + req.candidates
|
| 80 |
with contextlib.redirect_stdout(StringIO()), contextlib.redirect_stderr(StringIO()):
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
|
| 86 |
-
def _cosine(a, b) -> float:
|
| 87 |
-
dot = sum(x * y for x, y in zip(a, b))
|
| 88 |
-
na = math.sqrt(sum(x * x for x in a)) or 1e-9
|
| 89 |
-
nb = math.sqrt(sum(x * x for x in b)) or 1e-9
|
| 90 |
-
return dot / (na * nb)
|
| 91 |
|
| 92 |
-
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 93 |
-
# PING — keeps the Space from going to sleep (set up a cron job that calls it every 4 minutes)
|
| 94 |
-
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 95 |
@app.get("/ping")
def ping():
    """Answer every request with the constant payload ``{"pong": True}``."""
    payload = {"pong": True}
    return payload
|
|
|
|
| 1 |
"""
|
| 2 |
KG Embedding Server — FastAPI on HuggingFace Spaces
|
| 3 |
+
Session-based: client syncs once → server holds FAISS index in RAM
|
| 4 |
"""
|
| 5 |
|
| 6 |
+
from fastapi import FastAPI
|
| 7 |
from pydantic import BaseModel
|
| 8 |
+
from typing import Dict, Any, Optional
|
| 9 |
+
import math, contextlib, uuid
|
|
|
|
|
|
|
| 10 |
from io import StringIO
|
| 11 |
|
| 12 |
app = FastAPI(title="KG Embedding Server")
|
| 13 |
|
| 14 |
# ──────────────────────────────────────────────────────
# GLOBALS
# ──────────────────────────────────────────────────────
|
| 17 |
# Embedding model handle; None until a model has been assigned to it.
_model = None
# Flag reported by /health as "model_loaded"; /sync refuses to index while it is False.
_use_st = False
# In-RAM session store: session_id -> {"nodes": [...], "index": <FAISS index or None>}
_sessions: Dict[str, Dict[str, Any]] = {}
|
| 20 |
|
| 21 |
@app.on_event("startup")
|
| 22 |
def load_model():
|
|
|
|
| 29 |
print("[Server] sentence-transformers loaded β")
|
| 30 |
except Exception as e:
|
| 31 |
print(f"[Server] ST unavailable: {e}")
|
|
|
|
| 32 |
|
| 33 |
# ──────────────────────────────────────────────────────
# MODELS
# ──────────────────────────────────────────────────────
|
| 36 |
+
class NodePayload(BaseModel):
    # A single node as uploaded by the client in a /sync request.

    # Client-side identifier; echoed back as `node_id` in search results.
    id: int
    # Joined with `content` (space-separated) to form the text that is embedded.
    title: str
    # Body text of the node.
    content: str
|
| 40 |
|
| 41 |
+
class SyncRequest(BaseModel):
    # Request body for POST /sync.

    # Key under which the server stores the nodes and their index.
    session_id: str
    # Nodes to embed and index; an empty list is accepted.
    nodes: list[NodePayload]
|
| 44 |
|
| 45 |
+
class SyncResponse(BaseModel):
    # Response body for POST /sync.

    # "ok" on success, "no_model" when no embedding model is available.
    status: str
    # Number of nodes that were indexed (0 for an empty upload or no model).
    count: int
|
| 48 |
+
|
| 49 |
+
class SearchRequest(BaseModel):
    # Request body for POST /search.

    # Session whose cached index should be queried.
    session_id: str
    # Free-text query; embedded on the server at request time.
    query: str
    # Upper bound on returned hits; capped at the session's node count.
    top_k: int = 8
|
| 53 |
+
|
| 54 |
+
class SearchResult(BaseModel):
    # One hit returned by /search.

    # The `id` of the matching node as supplied by the client during sync.
    node_id: int
    # Similarity score reported by the index for this node.
    score: float
|
| 57 |
|
| 58 |
+
class SearchResponse(BaseModel):
    # Response body for POST /search.

    # Hits in the order the index returned them.
    results: list[SearchResult]
    # True when the session is unknown or has no index; the client should sync again.
    session_missing: bool = False
|
| 61 |
|
| 62 |
class HealthResponse(BaseModel):
|
| 63 |
status: str
|
|
|
|
| 71 |
def health():
|
| 72 |
return {"status": "ok", "model_loaded": _use_st}
|
| 73 |
|
| 74 |
+
@app.post("/sync", response_model=SyncResponse)
def sync(req: SyncRequest):
    """Client uploads nodes once; the server builds a FAISS index in RAM.

    Each node is embedded as ``"<title> <content>"`` with normalized
    embeddings and added to an inner-product index. The index and the node
    payloads are stored under ``req.session_id``, replacing any previous
    entry for that session.

    Args:
        req: Session identifier plus the nodes to index.

    Returns:
        ``SyncResponse`` with ``status="ok"`` and the number of nodes indexed,
        or ``status="no_model"`` with ``count=0`` when no embedding model is
        loaded.

    Raises:
        RuntimeError: If ``faiss`` or ``numpy`` cannot be imported.
    """
    try:
        import faiss
        import numpy as np
    except ImportError as err:
        # Chain the ImportError so the traceback names the missing dependency.
        raise RuntimeError("faiss not installed on server") from err

    if not _use_st or _model is None:
        return SyncResponse(status="no_model", count=0)

    nodes = req.nodes
    if not nodes:
        # Register the session without an index; /search treats it as missing.
        _sessions[req.session_id] = {"nodes": [], "index": None}
        return SyncResponse(status="ok", count=0)

    texts = [f"{node.title} {node.content}" for node in nodes]

    # Discard anything the encoder writes to stdout/stderr.
    with contextlib.redirect_stdout(StringIO()), contextlib.redirect_stderr(StringIO()):
        vecs = _model.encode(texts, show_progress_bar=False, normalize_embeddings=True)

    # FAISS expects a C-contiguous float32 matrix; this avoids a copy when
    # the encoder output already has that layout.
    vecs = np.ascontiguousarray(vecs, dtype="float32")
    dim = vecs.shape[1]

    # Inner product equals cosine similarity because the vectors are normalized.
    index = faiss.IndexFlatIP(dim)
    index.add(vecs)

    _sessions[req.session_id] = {
        "nodes": [node.dict() for node in nodes],
        "index": index,
    }
    print(f"[Server] Session {req.session_id[:8]}β¦ β {len(texts)} nodes indexed β")
    return SyncResponse(status="ok", count=len(texts))
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
@app.post("/search", response_model=SearchResponse)
def search(req: SearchRequest):
    """Embed the query only and search the session's cached FAISS index.

    Args:
        req: Session identifier, query text and the maximum number of hits.

    Returns:
        ``SearchResponse`` with one ``SearchResult`` per hit. When the session
        is unknown or has no index, ``results`` is empty and
        ``session_missing`` is True. When the effective ``k`` is not positive
        (for example ``top_k <= 0``), ``results`` is empty and
        ``session_missing`` stays False.
    """
    import numpy as np

    session = _sessions.get(req.session_id)
    if not session or session["index"] is None:
        return SearchResponse(results=[], session_missing=True)

    nodes = session["nodes"]
    k = min(req.top_k, len(nodes))
    if k <= 0:
        # The FAISS Python wrapper asserts k > 0; answer with no hits instead
        # of letting a non-positive top_k surface as a server error.
        return SearchResponse(results=[])

    # Discard anything the encoder writes to stdout/stderr.
    with contextlib.redirect_stdout(StringIO()), contextlib.redirect_stderr(StringIO()):
        qvec = _model.encode([req.query], show_progress_bar=False, normalize_embeddings=True)

    # FAISS expects a C-contiguous float32 matrix of shape (n_queries, dim).
    qvec = np.ascontiguousarray(qvec, dtype="float32")
    scores, indices = session["index"].search(qvec, k)

    # Row 0 belongs to the single query; negative labels mark unfilled slots.
    results = [
        SearchResult(node_id=nodes[idx]["id"], score=float(score))
        for score, idx in zip(scores[0], indices[0])
        if idx >= 0
    ]
    return SearchResponse(results=results)
|
| 132 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
|
|
|
|
|
|
|
|
|
|
| 134 |
@app.get("/ping")
def ping():
    """Answer every request with the constant payload ``{"pong": True}``."""
    payload = {"pong": True}
    return payload
|