1MR committed on
Commit
3b054bc
·
verified ·
1 Parent(s): 6c4cb67

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -25
app.py CHANGED
@@ -3,10 +3,10 @@ KG Embedding Server — FastAPI on HuggingFace Spaces
3
  Session-based: client syncs once → server holds FAISS index in RAM
4
  """
5
 
6
- from fastapi import FastAPI
7
  from pydantic import BaseModel
8
  from typing import Dict, Any, Optional
9
- import math, contextlib, uuid
10
  from io import StringIO
11
 
12
  app = FastAPI(title="KG Embedding Server")
@@ -14,13 +14,18 @@ app = FastAPI(title="KG Embedding Server")
14
  # ══════════════════════════════════════════════════════
15
  # GLOBALS
16
  # ══════════════════════════════════════════════════════
17
- _model = None
18
- _use_st = False
19
- _sessions: Dict[str, Dict[str, Any]] = {} # session_id → {nodes, index}
 
 
 
20
 
21
  @app.on_event("startup")
22
  def load_model():
23
- global _model, _use_st
 
 
24
  try:
25
  from sentence_transformers import SentenceTransformer
26
  with contextlib.redirect_stdout(StringIO()), contextlib.redirect_stderr(StringIO()):
@@ -30,9 +35,27 @@ def load_model():
30
  except Exception as e:
31
  print(f"[Server] ST unavailable: {e}")
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  # ══════════════════════════════════════════════════════
34
- # MODELS
35
  # ══════════════════════════════════════════════════════
 
36
  class NodePayload(BaseModel):
37
  id: int
38
  title: str
@@ -62,6 +85,8 @@ class SearchResponse(BaseModel):
62
  class HealthResponse(BaseModel):
63
  status: str
64
  model_loaded: bool
 
 
65
 
66
  # ══════════════════════════════════════════════════════
67
  # ENDPOINTS
@@ -69,46 +94,58 @@ class HealthResponse(BaseModel):
69
 
70
  @app.get("/health", response_model=HealthResponse)
71
  def health():
72
- return {"status": "ok", "model_loaded": _use_st}
 
 
 
 
 
73
 
74
  @app.post("/sync", response_model=SyncResponse)
75
  def sync(req: SyncRequest):
76
- """Client uploads nodes once → server builds FAISS index in RAM."""
77
- try:
78
- import faiss
79
- import numpy as np
80
- except ImportError:
81
- raise Exception("faiss not installed on server")
82
-
83
  if not _use_st or _model is None:
84
- return SyncResponse(status="no_model", count=0)
 
 
85
 
86
  texts = [f"{n.title} {n.content}" for n in req.nodes]
 
87
  if not texts:
88
  _sessions[req.session_id] = {"nodes": [], "index": None}
89
  return SyncResponse(status="ok", count=0)
90
 
 
91
  with contextlib.redirect_stdout(StringIO()), contextlib.redirect_stderr(StringIO()):
92
  vecs = _model.encode(texts, show_progress_bar=False, normalize_embeddings=True)
93
 
94
- vecs = np.array(vecs, dtype="float32")
95
- dim = vecs.shape[1]
96
 
97
- index = faiss.IndexFlatIP(dim) # cosine sim (vectors already normalized)
 
98
  index.add(vecs)
99
 
100
  _sessions[req.session_id] = {
101
- "nodes": [n.dict() for n in req.nodes],
102
  "index": index,
103
  }
104
- print(f"[Server] Session {req.session_id[:8]}… β†’ {len(texts)} nodes indexed βœ“")
 
105
  return SyncResponse(status="ok", count=len(texts))
106
 
107
 
108
  @app.post("/search", response_model=SearchResponse)
109
  def search(req: SearchRequest):
110
- """Embed query only, FAISS search against cached index."""
111
- import numpy as np
 
 
 
 
112
 
113
  session = _sessions.get(req.session_id)
114
  if not session or session["index"] is None:
@@ -117,8 +154,9 @@ def search(req: SearchRequest):
117
  with contextlib.redirect_stdout(StringIO()), contextlib.redirect_stderr(StringIO()):
118
  qvec = _model.encode([req.query], normalize_embeddings=True)
119
 
120
- qvec = np.array(qvec, dtype="float32")
121
  k = min(req.top_k, len(session["nodes"]))
 
122
  scores, indices = session["index"].search(qvec, k)
123
 
124
  results = []
@@ -126,11 +164,13 @@ def search(req: SearchRequest):
126
  if idx >= 0:
127
  results.append(SearchResult(
128
  node_id=session["nodes"][idx]["id"],
129
- score=float(score)
130
  ))
 
131
  return SearchResponse(results=results)
132
 
133
 
134
  @app.get("/ping")
135
  def ping():
 
136
  return {"pong": True}
 
3
  Session-based: client syncs once → server holds FAISS index in RAM
4
  """
5
 
6
+ from fastapi import FastAPI, HTTPException
7
  from pydantic import BaseModel
8
  from typing import Dict, Any, Optional
9
+ import contextlib
10
  from io import StringIO
11
 
12
  app = FastAPI(title="KG Embedding Server")
 
14
  # ══════════════════════════════════════════════════════
15
  # GLOBALS
16
  # ══════════════════════════════════════════════════════
17
+ _model = None
18
+ _use_st = False
19
+ _faiss = None
20
+ _np = None
21
+ _sessions: Dict[str, Dict[str, Any]] = {}
22
+
23
 
24
  @app.on_event("startup")
25
  def load_model():
26
+ global _model, _use_st, _faiss, _np
27
+
28
+ # sentence-transformers
29
  try:
30
  from sentence_transformers import SentenceTransformer
31
  with contextlib.redirect_stdout(StringIO()), contextlib.redirect_stderr(StringIO()):
 
35
  except Exception as e:
36
  print(f"[Server] ST unavailable: {e}")
37
 
38
+ # faiss
39
+ try:
40
+ import faiss as _faiss_mod
41
+ _faiss = _faiss_mod
42
+ print("[Server] faiss loaded βœ“")
43
+ except Exception as e:
44
+ print(f"[Server] faiss unavailable: {e}")
45
+
46
+ # numpy
47
+ try:
48
+ import numpy as np
49
+ _np = np
50
+ print("[Server] numpy loaded βœ“")
51
+ except Exception as e:
52
+ print(f"[Server] numpy unavailable: {e}")
53
+
54
+
55
  # ══════════════════════════════════════════════════════
56
+ # REQUEST / RESPONSE MODELS
57
  # ══════════════════════════════════════════════════════
58
+
59
  class NodePayload(BaseModel):
60
  id: int
61
  title: str
 
85
  class HealthResponse(BaseModel):
86
  status: str
87
  model_loaded: bool
88
+ faiss_loaded: bool
89
+
90
 
91
  # ══════════════════════════════════════════════════════
92
  # ENDPOINTS
 
94
 
95
  @app.get("/health", response_model=HealthResponse)
96
  def health():
97
+ return {
98
+ "status": "ok",
99
+ "model_loaded": _use_st,
100
+ "faiss_loaded": _faiss is not None,
101
+ }
102
+
103
 
104
  @app.post("/sync", response_model=SyncResponse)
105
  def sync(req: SyncRequest):
106
+ """
107
+ Client uploads all nodes once.
108
+ Server generates embeddings + builds FAISS index in RAM.
109
+ """
 
 
 
110
  if not _use_st or _model is None:
111
+ raise HTTPException(503, "sentence-transformers not loaded")
112
+ if _faiss is None or _np is None:
113
+ raise HTTPException(503, "faiss / numpy not loaded")
114
 
115
  texts = [f"{n.title} {n.content}" for n in req.nodes]
116
+
117
  if not texts:
118
  _sessions[req.session_id] = {"nodes": [], "index": None}
119
  return SyncResponse(status="ok", count=0)
120
 
121
+ # embed all nodes in one batch with normalization (so cosine = dot product afterwards)
122
  with contextlib.redirect_stdout(StringIO()), contextlib.redirect_stderr(StringIO()):
123
  vecs = _model.encode(texts, show_progress_bar=False, normalize_embeddings=True)
124
 
125
+ vecs = _np.array(vecs, dtype="float32")
126
+ dim = vecs.shape[1]
127
 
128
+ # IndexFlatIP = exact inner product (= cosine after normalization)
129
+ index = _faiss.IndexFlatIP(dim)
130
  index.add(vecs)
131
 
132
  _sessions[req.session_id] = {
133
+ "nodes": [{"id": n.id, "title": n.title} for n in req.nodes],
134
  "index": index,
135
  }
136
+
137
+ print(f"[Server] /sync β†’ session={req.session_id[:8]}… | {len(texts)} nodes indexed βœ“")
138
  return SyncResponse(status="ok", count=len(texts))
139
 
140
 
141
  @app.post("/search", response_model=SearchResponse)
142
  def search(req: SearchRequest):
143
+ """
144
+ Embed query only → exact FAISS inner-product search against cached index.
145
+ Zero candidate transfer per search.
146
+ """
147
+ if not _use_st or _model is None:
148
+ raise HTTPException(503, "sentence-transformers not loaded")
149
 
150
  session = _sessions.get(req.session_id)
151
  if not session or session["index"] is None:
 
154
  with contextlib.redirect_stdout(StringIO()), contextlib.redirect_stderr(StringIO()):
155
  qvec = _model.encode([req.query], normalize_embeddings=True)
156
 
157
+ qvec = _np.array(qvec, dtype="float32")
158
  k = min(req.top_k, len(session["nodes"]))
159
+
160
  scores, indices = session["index"].search(qvec, k)
161
 
162
  results = []
 
164
  if idx >= 0:
165
  results.append(SearchResult(
166
  node_id=session["nodes"][idx]["id"],
167
+ score=float(score),
168
  ))
169
+
170
  return SearchResponse(results=results)
171
 
172
 
173
  @app.get("/ping")
174
  def ping():
175
+ """Keep-alive — set up a cron job that calls this every 4 minutes."""
176
  return {"pong": True}